diff --git a/.bazelrc b/.bazelrc index ceba7bfdbac74d1e44aadc3010e5e84bd36ce3ee..c70c57136102b483a4332ca22f775d7a2c5b849e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -25,12 +25,14 @@ build --define framework_shared_object=true # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true +build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. build:mkl_open_source_only --define=build_with_mkl_dnn_only=true build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -78,7 +80,7 @@ build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true build --spawn_strategy=standalone -build --genrule_strategy=standalone +build --strategy=Genrule=standalone build -c opt # Other build flags. @@ -93,9 +95,6 @@ build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include -# Disable MKL-DNN contraction kernels by default. -build --define=tensorflow_mkldnn_contraction_kernel=0 - # Default options should come above this line # Options from ./configure diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a296f265f7b9521c46d350cec26ff199f43eb6c..b978f89f9e1d79dd4f7481711a59c2b94e8bf01b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -150,41 +150,45 @@ may exist in your changes. There are two ways to run TensorFlow unit tests. -1. Using tools and libraries installed directly on your system. +1. Using tools and libraries installed directly on your system. - Refer to the - [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and - [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) - for the required packages. Alternatively, use the said - [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., - `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` - for development to avoid installing the packages directly on your system. + Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) + and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and + `tensorflow/tensorflow:nightly-devel-gpu` for development to avoid + installing the packages directly on your system (in which case remember to + change directory from `/root` to `/tensorflow` once you get into the running + container so `bazel` can find the `tensorflow` workspace). - Once you have the packages installed, you can run a specific unit test in - bazel by doing as follows: + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: - If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add - the `cuda` option flag + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + export flags="--config=opt --config=cuda -k" + ``` - For example, to run all tests under tensorflow/python, do: + For example, to run all tests under tensorflow/python, do: - ```bash - bazel test ${flags} //tensorflow/python/... - ``` + ```bash + bazel test ${flags} //tensorflow/python/... + ``` -2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. - ```bash - # Install Docker first, then this will build and run cpu tests - tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... - ``` - - See - [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. + ```bash + # Install Docker first, then this will build and run cpu tests + tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... + ``` + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) + for details. diff --git a/README.md b/README.md index 519815d006cc33be10132909baf414a4bd843435..4e37b239b16e6eeefc587aeb242a03e1f88eddbd 100644 --- a/README.md +++ b/README.md @@ -57,21 +57,24 @@ Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. - #### *Try your first TensorFlow program* + ```shell $ python ``` + ```python >>> import tensorflow as tf >>> tf.enable_eager_execution() ->>> tf.add(1, 2) +>>> tf.add(1, 2).numpy() 3 >>> hello = tf.constant('Hello, TensorFlow!') >>> hello.numpy() 'Hello, TensorFlow!' ``` -Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). + +Learn more examples about how to do specific tasks in TensorFlow at the +[tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines diff --git a/RELEASE.md b/RELEASE.md index 282430d12303bde980e19e3c3602eb91b1a54d63..0a56e6909870e398c9d6349576cd2f8e6734f072 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -849,7 +849,7 @@ answered questions, and were part of inspiring discussions. * Remove `tf.contrib.data.Iterator.from_dataset()` method. Use `Dataset.make_initializable_iterator()` instead. * Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. -* Reorder some TFGAN loss functions in a non-backwards compatible way. +* Reorder some TF-GAN loss functions in a non-backwards compatible way. ## Known Issues * In Python 3, `Dataset.from_generator()` does not support Unicode strings. diff --git a/WORKSPACE b/WORKSPACE index 2277e83a3f67b62cf4ee1311767ee06c0549c697..957b8d8528dc9b5e2ea134921b28601aa6fed2d1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file" http_archive( name = "io_bazel_rules_closure", - sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", - strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", + sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", + strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 ], ) @@ -73,7 +73,7 @@ swift_rules_dependencies() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.18.0") +check_bazel_version_at_least("0.19.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/configure.py b/configure.py index 1e732db26404906901a9eeab97a5e75137ee8388..adc9ef9caca8c0128c63896fdebbbadf7f86da81 100644 --- a/configure.py +++ b/configure.py @@ -33,7 +33,7 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDA_VERSION = '10.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' @@ -480,7 +480,9 @@ def check_bazel_version(min_version, max_version): if (curr_version_int > max_version_int and 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): print('Please downgrade your bazel installation to version %s or lower to ' - 'build TensorFlow!' % max_version) + 'build TensorFlow! To downgrade: download the installer for the old ' + 'version (from https://github.com/bazelbuild/bazel/releases) then ' + 'run the installer.' % max_version) sys.exit(1) return curr_version @@ -1554,7 +1556,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.19.0', '0.20.0') + check_bazel_version('0.19.0', '0.21.0') reset_tf_configure_bazelrc() @@ -1692,7 +1694,7 @@ def main(): config_info_line('noaws', 'Disable AWS S3 filesystem support.') config_info_line('nogcp', 'Disable GCP support.') config_info_line('nohdfs', 'Disable HDFS support.') - config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('noignite', 'Disable Apache Ignite support.') config_info_line('nokafka', 'Disable Apache Kafka support.') config_info_line('nonccl', 'Disable NVIDIA NCCL support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f07e7365d3482cde5b7bb76ebf22890150e98651..413806fac14ca4605606507726d7ff87ce73a699 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -343,6 +343,13 @@ config_setting( }, ) +config_setting( + name = "using_rocm_hipcc", + define_values = { + "using_rocm_hipcc": "true", + }, +) + config_setting( name = "with_mpi_support", values = {"define": "with_mpi_support=true"}, @@ -370,13 +377,22 @@ config_setting( define_values = {"tf_api_version": "2"}, ) +# This flag is defined for select statements that match both +# on 'windows' and 'api_version_2'. In this case, bazel requires +# having a flag which is a superset of these two. +config_setting( + name = "windows_and_api_version_2", + define_values = {"tf_api_version": "2"}, + values = {"cpu": "x64_windows"}, +) + package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", "//learning/deepmind/...", "//learning/meta_rank/...", - "//learning/pathways/...", # While dataset C++ api requires internals + "//platforms/performance/autograppler/...", "//tensorflow/...", "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 2c0a7452692e5cdb184f7f0a77eb1b646a1772d4..a93799bfe84b0f9c4743e1ad0effd6e69ad7f3f2 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -52,7 +52,7 @@ elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) # Enable TF2 behaviors -from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 514aba1b59631f882523396aab0f4d3d5e88a893..eeca8f0d566a6401cb64e4fe3f0ee3c5aeb4ece2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -62,11 +62,15 @@ if '__all__' in vars(): vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +# The 'app' module will be imported as part of the placeholder section above. app.flags = flags # pylint: disable=undefined-variable +# Also use 'app' module (choice is arbitrary) to derive the API directory below. +_API_MODULE = app # pylint: disable=undefined-variable + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3e1f220db233001ba652120657631f8c1a296b35..6e50a09bfc5ed3a8f2f7e05e6a6a151525e8dfce 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -83,7 +83,7 @@ tf_cuda_library( ], "//conditions:default": [ ":c_api_internal", - "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", "//tensorflow/cc:grad_ops", @@ -129,6 +129,7 @@ tf_cuda_library( "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", + "@com_google_absl//absl/strings", ], ) @@ -252,12 +253,6 @@ tf_cc_test( name = "c_test", srcs = ["c_test.c"], extra_copts = ["-std=c11"], - tags = [ - # TODO(b/121223209): Re-enable after fixing asan memory leaks and MacOS - # build errors. - "noasan", - "no_mac", - ], deps = [ ":c_api", ":c_api_experimental", @@ -294,13 +289,20 @@ tf_cuda_cc_test( "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/compiler/jit", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/kernels:array", @@ -324,6 +326,7 @@ tf_cc_test( deps = [ ":c_api", ":c_api_experimental", + ":c_api_internal", ":c_test_util", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_test_util", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9580215a317b1a6b1cdacbd430a1764af61be990..94d9f4a6fa2f14cb3343bdd51b7e4d61944444d0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -257,6 +257,74 @@ int64_t TF_Dim(const TF_Tensor* t, int dim_index) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } +int64_t TF_TensorElementCount(const TF_Tensor* t) { + int64_t result = 1; + int rank = TF_NumDims(t); + for (int dim = 0; dim < rank; ++dim) { + result *= TF_Dim(t, dim); + } + return result; +} + +// Returns the number of elements that would be present in a tensor with the +// given shape. +static int64_t ShapeNumElements(const int64_t* dims, int num_dims) { + int64_t result = 1; + for (int dim = 0; dim < num_dims; ++dim) { + result *= dims[dim]; + } + return result; +} + +static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Unref(); + } +} + +static void RefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Ref(); + } +} + +void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, + TF_Tensor* to, const int64_t* new_dims, + int num_new_dims, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + size_t in_size = TF_DataTypeSize(TF_TensorType(from)); + if (in_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor has a zero-sized data type"); + return; + } + size_t out_size = TF_DataTypeSize(type); + if (out_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "output tensor has a zero-sized data type"); + return; + } + + if (ShapeNumElements(new_dims, num_new_dims) * out_size != + TF_TensorElementCount(from) * in_size) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor is not compatible with output shape"); + return; + } + + tensorflow::TensorShapeProto p; + for (int i = 0; i < num_new_dims; ++i) { + p.add_dim()->set_size(new_dims[i]); + } + to->shape = tensorflow::TensorShape(p); + to->dtype = type; + if (to->buffer != from->buffer) { + UnrefIfNonNull(to->buffer); + to->buffer = from->buffer; + RefIfNonNull(to->buffer); + } +} + // -------------------------------------------------------------------------- size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { @@ -2881,6 +2949,9 @@ const char* TF_ServerTarget(TF_Server* server) { #endif } -void TF_DeleteServer(TF_Server* server) { delete server; } - +void TF_DeleteServer(TF_Server* server) { +#ifndef __ANDROID__ + delete server; +#endif +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c7abba85521fccec07983cd5ab4f94a8368d6181..8031928dac4de2391f0aec46e69d61a137606e4d 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -272,6 +272,39 @@ TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); +// Returns the number of elements in the tensor. +TF_CAPI_EXPORT extern int64_t TF_TensorElementCount(const TF_Tensor* tensor); + +// Copy the internal data representation of `from` to `to`. `new_dims` and +// `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its +// data type. On success, *status is set to TF_OK and the two tensors share the +// same data buffer. +// +// This call requires that the `from` tensor and the given type and shape (dims +// and num_dims) are "compatible" (i.e. they occupy the same number of bytes). +// Specifically, given from_type_size = TF_DataTypeSize(TF_TensorType(from)): +// +// ShapeElementCount(dims, num_dims) * TF_DataTypeSize(type) +// +// must equal +// +// TF_TensorElementCount(from) * from_type_size +// +// where TF_ShapeElementCount would be the number of elements in a tensor with +// the given shape. +// +// In addition, this function requires: +// * TF_DataTypeSize(TF_TensorType(from)) != 0 +// * TF_DataTypeSize(type) != 0 +// +// If any of the requirements are not met, *status is set to +// TF_INVALID_ARGUMENT. +TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from, + TF_DataType type, TF_Tensor* to, + const int64_t* new_dims, + int num_new_dims, + TF_Status* status); + // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format // required by TF_STRING tensors. Does not write to memory more than `dst_len` diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index f04b285037dff403428ed74fe90eac60339fe36b..6cc74cfb3246e9526e862f363590ce43e390ffaa 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" @@ -128,6 +129,14 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { return ret; } +char* TF_FunctionDebugString(TF_Function* func, size_t* len) { + const auto& debug_str = func->fdef.DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + // On success, returns a set of TF_Function instances from `text_proto` of // GraphDef type. These functions must be deleted by calling TF_DeleteFunction. // @@ -8737,6 +8746,12 @@ static void CheckOk(TF_Status* status) { void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { auto* status = TF_NewStatus(); + if (!TFE_TensorHandleIsConcrete(handle)) { + VLOG(1) << "Symbolic tensor: " << handle; + TF_DeleteStatus(status); + return; + } + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -8748,6 +8763,11 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } +void TFE_OpPrintDebugString(TFE_Op* op) { + VLOG(1) << "TFE_OpPrintDebugString() over " << op; + LOG(INFO) << op->operation.DebugString(); +} + struct TFE_ExecuteOpNotification { TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} tensorflow::Notification n; @@ -8941,3 +8961,161 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, } status->status = EnableCollectiveOps(server_def, ctx); } + +std::string tensorflow::getTF_OutputDebugString(TF_Output node) { + return absl::Substitute("TF_Output($0, $1)", node.oper, node.index); +} + +using tensorflow::getTF_OutputDebugString; + +TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput(TF_Output t, + TF_DataType dtype) { + auto ret = new TFE_TensorHandle(t, dtype); + VLOG(1) << "Storing TFOutput " << getTF_OutputDebugString(t) + << " into tensor handle " << ret << " with internal handle " + << ret->handle; + return ret; +} + +unsigned char TFE_TensorHandleIsConcrete(TFE_TensorHandle* handle) { + assert(handle->handle != nullptr); + return handle->handle->getSymbolicTensor() == nullptr; +} + +TF_Output TFE_GetTFOutputFromTensorHandle(TFE_TensorHandle* handle, + TF_Status* status) { + if (TFE_TensorHandleIsConcrete(handle)) { + status->status = + tensorflow::errors::Internal("Not a symbolic tensor: ", handle); + return TF_Output{nullptr, -1}; + } + + auto* sym_tensor = handle->handle->getSymbolicTensor(); + CHECK(sym_tensor != nullptr); + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "Retrieving " << getTF_OutputDebugString(ret) + << " from tensor handle " << handle; + CHECK_GE(sym_tensor->index, 0); + return ret; +} + +TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph) { + return new TFE_TraceContext(graph); +} + +void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx) { delete trace_ctx; } + +// If `handle` is already symbolic, return it. Otherwise map it to a new +// symbolic tensor (a PlaceHolder op) and return that. +static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx, + tensorflow::TensorHandle* handle, + TF_Status* status) { + VLOG(1) << "Getting symbolic tensor for input tensor handle " << handle + << ": " << handle->DebugString(); + + auto* sym_tensor = handle->getSymbolicTensor(); + if (sym_tensor != nullptr) { + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "This handle is a symbolic tensor " << sym_tensor << ": " + << getTF_OutputDebugString(ret); + return ret; + } + + auto find_it = trace_ctx->input_tensor_map.find(handle); + if (find_it != trace_ctx->input_tensor_map.end()) { + VLOG(1) << "There exists a map entry from this concrete tensor to: " + << getTF_OutputDebugString(find_it->second); + return find_it->second; + } + + auto node_name = tensorflow::strings::StrCat("additional_input_", + trace_ctx->node_counter++); + VLOG(1) << "Adding a place holder node named " << node_name; + auto* desc = + TF_NewOperation(trace_ctx->graph, "Placeholder", node_name.c_str()); + TF_SetAttrType(desc, "dtype", + static_cast(handle->dtype) /*TF_FLOAT*/); + auto* result = TF_FinishOperation(desc, status); + if (!status->status.ok()) { + return TF_Output{nullptr, -1}; + } + + auto ret = TF_Output{result, 0}; + VLOG(1) << "Creating a new map entry to map to: " + << getTF_OutputDebugString(ret); + trace_ctx->input_tensor_map[handle] = ret; + // `handle` could be destroyed before it's read from `input_tensor_map` (say + // during a subsequent TFE_FinalizeInputTensorsFromTraceContext() call), so we + // increment its ref count to extend its life span to that of `trace_ctx`. + handle->Ref(); + VLOG(1) << "Ref count for handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + return ret; +} + +void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status) { + VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": " + << op->operation.DebugString(); + + const auto& op_type = op->operation.Name(); + auto op_name = + tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); + auto* desc = + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + for (auto* input : op->operation.Inputs()) { + auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status); + if (!status->status.ok()) return; + TF_AddInput(desc, symbolic_input); + } + + VLOG(1) << "Adding attrs."; + // TODO(hongm): add attrs + + auto* graph_op = TF_FinishOperation(desc, status); + if (!status->status.ok()) return; + + VLOG(1) << "Op finalized; setting return tensors."; + *num_retvals = TF_OperationNumOutputs(graph_op); + VLOG(1) << "This op has " << *num_retvals << " outputs."; + for (int i = 0; i < *num_retvals; ++i) { + auto output = TF_Output{graph_op, i}; + auto dtype = TF_OperationOutputType(output); + retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype); + } +} + +int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) { + if (trace_ctx->input_tensors == nullptr) { + trace_ctx->input_tensors = + new std::vector>(); + trace_ctx->input_tensors->reserve(trace_ctx->input_tensor_map.size()); + + for (auto input : trace_ctx->input_tensor_map) { + trace_ctx->input_tensors->emplace_back(input.first, input.second); + } + } + return trace_ctx->input_tensor_map.size(); +} + +TF_Output TFE_GetInputGraphNodeFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + return trace_ctx->input_tensors->at(idx).second; +} + +TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + auto* handle = trace_ctx->input_tensors->at(idx).first; + VLOG(1) << "Ref count for internal handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + handle->Ref(); + auto* ret = new TFE_TensorHandle(handle); + VLOG(1) << "Returning a new tensor handle " << ret << ": " + << handle->DebugString(); + return ret; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index e6d04d0c2b25a3f7b1ebf50c58268f003595a520..48ea0ec1ed78a071b7bf7c858881d943a3ff3acd 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -84,6 +84,15 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, size_t* len); +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +// +// Do not return const char*, because some foreign language binding +// (e.g. swift) cannot then call free() on the returned pointer. +TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, + size_t* len); + // Creates a stack of data set + iterator nodes, currently hard-coded to return // a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, // returns the IteratorGetNext node, which caller can run or feed into an node. @@ -181,6 +190,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op); + typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; // Allows invoking a kernel asynchronously, and explicitly returns a @@ -255,6 +266,55 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, const void* proto, size_t proto_len, TF_Status* status); + +// Create a symbolic tensor from the input graph node. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput( + TF_Output t, TF_DataType data_type); + +// Returns 0 if the input tensor handle represents a symbolic tensor (i.e., a +// graph node). Otherwise returns non-0. +TF_CAPI_EXPORT extern unsigned char TFE_TensorHandleIsConcrete( + TFE_TensorHandle* handle); + +// If `handle` is a symbolic tensor, return the corresponding graph node +// represented by TF_Output. Otherwise, return an error status. +TF_CAPI_EXPORT extern TF_Output TFE_GetTFOutputFromTensorHandle( + TFE_TensorHandle* handle, TF_Status* status); + +typedef struct TFE_TraceContext TFE_TraceContext; + +// A trace context contains a trace graph, to which TFE_AddEagerOpToGraph() +// calls add graph nodes as a way to symbolically execute the eager ops. +// +// It also contains a hash map from concrete input tensors to symbolic +// tensors. That map will be used to create input tensors to the trace graph. +TF_CAPI_EXPORT extern TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph); + +TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx); + +// Symbolically executes `op`, by adding a corresponding node to the graph +// associated with `trace_ctx`. This graph node outputs a set of symbolic +// tensors in `retvals` and `num_retvals`. +TF_CAPI_EXPORT extern void TFE_AddEagerOpToGraph(TFE_Op* op, + TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, + int* num_retvals, + TF_Status* status); + +// Finalizes the trace graph and its inputs, and returns the number of inputs. +// After this call, the next two APIs can be called to iterate over the input +// tensors. +TF_CAPI_EXPORT extern int TFE_FinalizeInputTensorsFromTraceContext( + TFE_TraceContext* trace_ctx); + +TF_CAPI_EXPORT extern TF_Output TFE_GetInputGraphNodeFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx); + +// Each input tensor should be consumed at most once. +TF_CAPI_EXPORT extern TFE_TensorHandle* +TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index daa7701b7fe7e8ce757b6504329cf6434ad39778..4cfcf2ef3b2ccd9d8aedaf8efa4a31ac12d91c1b 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -296,5 +297,73 @@ TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, SymbolicTensor) { + TF_Status* status = TF_NewStatus(); + auto node = TF_Output{nullptr, 1}; + auto* sym_handle = TFE_NewTensorHandleFromTFOutput(node, TF_FLOAT); + TFE_TensorHandlePrintDebugString(sym_handle); + CHECK_EQ(TFE_TensorHandleDataType(sym_handle), TF_FLOAT); + ASSERT_FALSE(TFE_TensorHandleIsConcrete(sym_handle)); + auto same_node = TFE_GetTFOutputFromTensorHandle(sym_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(same_node.oper, node.oper); + ASSERT_EQ(same_node.index, node.index); + TFE_DeleteTensorHandle(sym_handle); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + ASSERT_TRUE(TFE_TensorHandleIsConcrete(m)); + (void)TFE_GetTFOutputFromTensorHandle(m, status); + CHECK_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(m); + + TF_DeleteStatus(status); +} + +TEST(CAPI_EXPERIMENTAL, DebugPrintAndSymbolicExecution) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* op = MatMulOp(ctx, m, m); + + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpPrintDebugString(op); + + auto* graph = TF_NewGraph(); + auto* trace_ctx = TFE_NewTraceContext(graph); + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx`. + TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx); + CHECK_EQ(num_inputs, 1); + auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx, + /*idx*/ 0); + + LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor); + auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx, + /*idx*/ 0); + TFE_TensorHandlePrintDebugString(handle); + TFE_DeleteTensorHandle(handle); + + CHECK_EQ(num_retvals, 1); + CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteTraceContext(trace_ctx); + TF_DeleteGraph(graph); + + TFE_DeleteTensorHandle(m); + TFE_DeleteOp(op); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 5ba26d3c585350aa510f9970cbfc246a9a108543..73283d775639b297857b2a50007dc7c28b1f39a3 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -228,6 +228,8 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) LOCKS_EXCLUDED(session->graph->mu, session->mu); +std::string getTF_OutputDebugString(TF_Output node); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d5934a10395ae094f65d3bc8b6cd7b94dbd32410..2be03bf0de6277fc63c353ad6dc63bec096a6993 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -163,6 +163,7 @@ TEST(CAPI, AllocateTensor) { EXPECT_EQ(dims[0], TF_Dim(t, 0)); EXPECT_EQ(dims[1], TF_Dim(t, 1)); EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); + EXPECT_EQ(6, TF_TensorElementCount(t)); TF_DeleteTensor(t); } @@ -1467,6 +1468,41 @@ TEST(CAPI, DeletingNullPointerIsSafe) { TF_DeleteStatus(status); } +TEST(CAPI, TestBitcastFrom_Reshape) { + int64_t dims[] = {2, 3}; + TF_Tensor* a = + TF_AllocateTensor(TF_UINT64, dims, 2, 6 * TF_DataTypeSize(TF_UINT64)); + TF_Tensor* b = + TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64)); + EXPECT_NE(a, nullptr); + EXPECT_NE(b, nullptr); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(1, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + int64_t new_dims[] = {3, 2}; + TF_Status* status = TF_NewStatus(); + TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(6, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + // Check that a write to one tensor shows up in the other. + *(static_cast(TF_TensorData(a))) = 4; + EXPECT_EQ(4, *(static_cast(TF_TensorData(b)))); + *(static_cast(TF_TensorData(b))) = 6; + EXPECT_EQ(6, *(static_cast(TF_TensorData(a)))); + + TF_DeleteTensor(a); + TF_DeleteTensor(b); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c index c0ed5ccd15d9524e2c14630d8ef92f6b3ef9b059..b86d8eb8e300e02a3871ecd5f424a82c521b18fc 100644 --- a/tensorflow/c/c_test.c +++ b/tensorflow/c/c_test.c @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include +#include #include #include @@ -32,7 +32,12 @@ void compute(void* kernel, TF_OpKernelContext* ctx) { TF_Status* s = TF_NewStatus(); TF_GetInput(ctx, 0, &input, s); TF_DeleteTensor(input); + + TF_DataType type; + TF_OpKernelContext_GetAttrType(ctx, "foobar", &type, s); + TF_DeleteStatus(s); + } // Exercises tensorflow's C API. @@ -68,6 +73,10 @@ int main(int argc, char** argv) { } fprintf(stderr, "wrote %s\n", full_path); free(full_path); + TF_CloseWritableFile(h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_CloseWritableFile failed: %s\n", TF_Message(status)); + } TF_StringStreamDone(s); TF_KernelBuilder* b = diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7..04dfefa6da28429b73856d392d94fa402ecda580 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,11 +3,19 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cuda_cc_test", - "tf_cc_test", "tf_copts", - "tfe_xla_copts", + "tf_cuda_cc_test", "tf_cuda_library", + "tfe_xla_copts", +) +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_device_tracer_test_flags", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", ) tf_cuda_library( @@ -62,6 +70,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/lib:eager_profiler", "//tensorflow/core:gpu_runtime", ], ) @@ -101,6 +110,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/profiler/lib:eager_profiler", ], ) @@ -148,6 +158,88 @@ tf_cuda_cc_test( ], ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = ["c_api_experimental.h"], + copts = tf_copts() + tfe_xla_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + ":c_api_internal", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }) + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", + ], + "//conditions:default": [], + }) + [ + "@com_google_absl//absl/memory", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core:gpu_runtime", + ], +) + +tf_cuda_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = [ + "c_api_experimental_test.cc", + ], + args = + ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/cc/profiler", + "//tensorflow/contrib/tpu/profiler:trace_events_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "tape", hdrs = ["tape.h"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 027d752f420238da867cb9d8c116640e1730caaa..af13f487af91594fedd4d5f77592682a6f98c34f 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -356,6 +356,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { if (h == nullptr) return; + VLOG(1) << "Deleting tensor handle " << h << " with internal handle " + << h->handle; if (h->handle) { h->handle->Unref(); } @@ -443,15 +445,15 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } // TODO(agarwal): move this implementation inside TFE_TensorHandle. - tensorflow::Device* d = nullptr; - tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->handle->TensorAndDevice(&t, &d, &op_device); - if (!status->status.ok()) return nullptr; tensorflow::TensorHandle* h_cpu = nullptr; - if (!IsCPU(d)) { - status->status = h->handle->CopyToDevice( - h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + + if (h->handle->IsRemote()) { + status->status = EagerCopyToDevice( + h->handle, h->handle->Context(), + h->handle->Context()->HostCPU()->name().c_str(), &h_cpu); if (!status->status.ok()) { return nullptr; } @@ -460,6 +462,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { h_cpu->Unref(); return nullptr; } + } else { + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + + if (!IsCPU(d)) { + status->status = h->handle->CopyToDevice( + h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } + } } TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); if (h_cpu != nullptr) { @@ -696,6 +714,7 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_Execute() on op " << op; tensorflow::gtl::InlinedVector handle_retvals( *num_retvals); status->status = @@ -738,6 +757,10 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, status->status = ctx->context.AddFunctionDef(function->fdef); } +unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { + return ctx->context.FindFunctionDef(name) != nullptr; +} + void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { ctx->context.SetShouldStoreMetadata(true); } @@ -774,7 +797,7 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.RunMetadataProto()->Clear(); + ctx->context.ClearRunMetadata(); } namespace { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 120748ab763a3358b6e38e64bb3b6fd2ea32f7c3..044dfb7415b027b707af05a197fdb41fe1f6d2e5 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -393,6 +393,10 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Checks whether a function is registered under `name`. +TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, + const char* name); + // Enables tracing of RunMetadata on the ops executed from this context. TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 52b0824552855860dfb138f3ac9a5d3afa7dc965..ffcd5ace0b98597363abe63201bf6c328a03212f 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -83,7 +83,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( } } - if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (padded_shape.IsTuple()) { if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { // Currently, the only case of XlaTensor containing a tuple shape is to // represent 64 bit ints, doubles, and complex numbers (we don't support @@ -99,7 +99,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); const xla::Shape& shape1 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + if (shape0.IsTuple() || shape1.IsTuple()) { status->status = tensorflow::errors::InvalidArgument( "XlaTensors should not contain nested tuples. Shape: ", padded_shape.DebugString()); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..dab17505643e791e6294a64247898ae23769a055 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" + +using tensorflow::string; + +void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { + op->operation.ConsumeInput(h->handle); +} + +TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx) { + return new TFE_Profiler(ctx); +} + +void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } + +void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, + TF_Buffer* buf, TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + string content; + status->status = profiler->profiler->SerializeToString(&content); + void* data = tensorflow::port::Malloc(content.length()); + content.copy(static_cast(data), content.length(), 0); + buf->data = data; + buf->length = content.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} + +void TFE_StartProfilerServer(TFE_Context* ctx, int port) { + auto server_thread = tensorflow::StartProfilerServer(&ctx->context, port); + ctx->context.AddChildThread(std::move(server_thread)); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..8c85d0e51695fde09cf0e2bb3930f9173e6cfb54 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, + TF_Status* status); + +// A profiler which will start profiling when creating the object and will stop +// when the object is destroyed. It will profile all operations run under the +// given TFE_Context. Multiple instance of it can be created, but at most one +// of them will profile for each TFE_Context. +// Thread-safety: TFE_Profiler is thread-safe. +typedef struct TFE_Profiler TFE_Profiler; + +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx); +TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); + +// The output string is a binary string of tensorflow.tpu.Trace. User can write +// the string to file for offline analysis by tensorboard. +TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, + TFE_Profiler* profiler, + TF_Buffer* buf, + TF_Status* status); + +// Start a profiler grpc server which listens to specified port. It will start +// the server on its own thread. It can be shutdown by destructing TFE_Context. +// Creating multiple profiler server is allowed. The service defined in +// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable +// file following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_Context* ctx, int port); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..af55fee66e8708e39626da3b10b6dd2f73af92bb --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using tensorflow::string; + +namespace tensorflow { +namespace { + +static bool HasSubstr(absl::string_view base, absl::string_view substr) { + bool ok = str_util::StrContains(base, substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +void ExecuteWithProfiling(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_Profiler* profiler = TFE_NewProfiler(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + // Run op on GPU if it is present. + string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + } + + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Buffer* profiler_result = TF_NewBuffer(); + TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); + TFE_DeleteProfiler(profiler); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::tpu::Trace profile_proto; + EXPECT_TRUE(profile_proto.ParseFromString( + {reinterpret_cast(profiler_result->data), + profiler_result->length})); + string profile_proto_str = profile_proto.DebugString(); + if (!gpu_device_name.empty()) { + EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0")); + // device name with "stream:all" is collected by Device Tracer. + EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); + } + EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0")); + TF_DeleteBuffer(profiler_result); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} +TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82..3b9e681194b7cebc61d9028525d200c692bbd529 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/profiler/lib/eager_profiler.h" #include "tensorflow/core/public/version.h" struct TFE_ContextOptions { @@ -82,6 +83,12 @@ struct TFE_TensorHandle { TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; + + // Create a symbolic tensor. + TFE_TensorHandle(TF_Output t, TF_DataType dtype) + : handle(new tensorflow::TensorHandle( + tensorflow::OutputGraphNode{t.oper, t.index}, + static_cast(dtype))) {} }; struct TFE_TensorDebugInfo { @@ -100,6 +107,13 @@ struct TFE_Op { tensorflow::EagerOperation operation; }; +struct TFE_Profiler { + TFE_Profiler(TFE_Context* ctx) + : profiler(tensorflow::EagerProfiler::Create(&ctx->context)) {} + + std::unique_ptr profiler; +}; + namespace tensorflow { // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, @@ -107,4 +121,24 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const char* attr_name, TF_Status* status); } // namespace tensorflow +struct TFE_TraceContext { + TF_Graph* const graph; + + unsigned int node_counter = 0; + // Each tensor handle will have its ref count incremented when it's added as a + // map key, and decremented when this object is destroyed. + std::map input_tensor_map; + std::vector>* input_tensors = + nullptr; + + TFE_TraceContext(TF_Graph* graph) : graph(graph) {} + + ~TFE_TraceContext() { + delete input_tensors; + for (auto input : input_tensor_map) { + input.first->Unref(); + } + } +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6b39b79ee82f9c7baaf856e573a42b7da65691e5..3d1ca4fb4b561a03ea9d879b1876fb1fd08a3139 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -175,13 +175,8 @@ void TestRemoteExecute(bool async) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - auto* retval_task0 = TFE_TensorHandleCopyToDevice( - retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteTensorHandle(retval_task0); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 2a4eaecb6cf2740a522b1e849d1306ebde6c4577..9505bf9dda32b9a338b574f1d31ec555a5628c6a 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -48,9 +48,10 @@ TF_KernelBuilder* TF_NewKernelBuilder( } void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { - DCHECK_NE(builder, nullptr); - delete builder->cc_builder; - delete builder; + if (builder != nullptr) { + delete builder->cc_builder; + delete builder; + } } namespace tensorflow { @@ -158,3 +159,44 @@ void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, cc_ctx->set_output(i, cc_tensor); } } + +void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +#define DEFINE_TF_GETATTR_(struct_name, func, c_type, cc_type) \ + void struct_name##_GetAttr##func(struct_name* ctx, const char* attr_name, \ + c_type* val, TF_Status* status) { \ + TF_SetStatus(status, TF_OK, ""); \ + cc_type v; \ + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \ + ::tensorflow::Set_TF_Status_from_Status(status, s); \ + if (s.ok()) { \ + *val = static_cast(v); \ + } \ + } + +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ + DEFINE_TF_GETATTR_(TF_OpKernelConstruction, func, c_type, cc_type) \ + DEFINE_TF_GETATTR_(TF_OpKernelContext, func, c_type, cc_type) + +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) + +TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return static_cast(cc_ctx->expected_output_dtype(i)); +} + +int64_t TF_StepId(TF_OpKernelContext* ctx) { + return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index cefc30bcdf89bdc14a4406299cc29f74153e77ac..b015d0103969355e8566242bfcc007f697c6ae18 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -111,6 +111,41 @@ TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, TF_Status* status); +// Notifies the given OpKernelConstruction that kernel construction has failed. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( + TF_OpKernelConstruction* ctx, TF_Status* status); + +// Notifies the given OpKernelContext that the kernel's compute function has +// failed. +TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, + TF_Status* status); + +// Returns the expected output data type of the ith output. If i < 0 or +// i >= TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( + TF_OpKernelContext* ctx, int i); + +// Returns the step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); + +// Interprets the named kernel construction attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + +// Interprets the named kernel context attribute as a TF_DataType and places it +// into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelContext_GetAttrType( + TF_OpKernelContext* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e659ee3c3d258a626ccf03a782ec031b5a703a48..0d2954717e7a83c102a35815809a554e3a917e07 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/kernels.h" #include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/op.h" @@ -41,6 +42,19 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; + if (ctx != nullptr) { + TF_Status* status = TF_NewStatus(); + + EXPECT_EQ(43, TF_StepId(ctx)); + + // Exercise attribute reads. + TF_DataType type; + TF_OpKernelContext_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + EXPECT_EQ(TF_FLOAT, type); + + TF_DeleteStatus(status); + } } static void MyDeleteFunc(void* kernel) { @@ -61,6 +75,11 @@ static std::unique_ptr GetFakeKernel(const char* device_name, def.set_device(device_name); def.add_input("input1"); def.add_input("input2"); + + AttrValue v; + v.set_type(DataType::DT_FLOAT); + (*def.mutable_attr())["SomeDataTypeAttr"] = v; + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, status); } @@ -75,7 +94,8 @@ TEST(TestKernel, TestRegisterKernelBuilder) { REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") - .Output("output1: uint8"); + .Output("output1: uint8") + .Attr("SomeDataTypeAttr: type"); TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); @@ -126,7 +146,8 @@ TEST(TestKernel, TestInputAndOutputCount) { REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") - .Output("output1: uint8"); + .Output("output1: uint8") + .Attr("SomeDataTypeAttr: type"); static int num_inputs = 0; static int num_outputs = 0; @@ -155,6 +176,8 @@ TEST(TestKernel, TestInputAndOutputCount) { TF_SetOutput(ctx, 24, input, s); EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0)); + TF_DeleteStatus(s); if (input != nullptr) { TF_DeleteTensor(input); @@ -175,6 +198,7 @@ TEST(TestKernel, TestInputAndOutputCount) { OpKernelContext::Params p; DummyDevice dummy_device(nullptr, false); p.device = &dummy_device; + p.step_id = 43; Tensor t(tensorflow::uint8(123)); @@ -200,4 +224,8 @@ TEST(TestKernel, TestInputAndOutputCount) { } } +TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { + TF_DeleteKernelBuilder(nullptr); +} + } // namespace tensorflow diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index cf65fe1ab99b49207a64e86310178141b30d07d7..e9838d9aba6554b40082187057851e9c896f8352 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -10,7 +10,7 @@ tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], tags = [ - "noguitar", # b/77649654 + "nogpu", # b/77649654 ], deps = [ ":profiler", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2..10f7abf09e925c0c31cfd595ecee4605f189476f 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -42,9 +43,28 @@ auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); +auto* load_latency_by_stage = monitoring::Sampler<2>::New( + { + "/tensorflow/cc/saved_model/load_latency_by_stage", // metric name + "Distribution of wall time spent (in microseconds) in each stage " + "(restore graph from disk, run init graph op, etc) when loading the " + "model", + "model_path", + "stage", + }, + // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes). + monitoring::Buckets::Exponential(10, 1.8, 33)); + constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; +uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { + const uint64 end_microseconds = Env::Default()->NowMicros(); + // Avoid clock skew. + if (end_microseconds < start_microseconds) return 0; + return end_microseconds - start_microseconds; +} + Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { @@ -242,6 +262,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { + const uint64 read_start_microseconds = Env::Default()->NowMicros(); TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, &bundle->meta_graph_def)); @@ -256,12 +277,23 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); + // Record walltime spent in restoring graph from disk, but postpone metric + // increments until graph init finishes. + const uint64 restore_graph_walltime = + GetLatencyMicroseconds(read_start_microseconds); + + const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); string init_op_name; TF_RETURN_IF_ERROR( GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, bundle->session.get(), init_op_name)); + load_latency_by_stage->GetCell(export_dir, "restore_graph") + ->Add(restore_graph_walltime); + // Record wall time spent in init op. + load_latency_by_stage->GetCell(export_dir, "init_graph") + ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); return Status::OK(); } @@ -275,16 +307,10 @@ Status LoadSavedModel(const SessionOptions& session_options, const uint64 start_microseconds = Env::Default()->NowMicros(); const Status status = LoadSavedModelInternal(session_options, run_options, export_dir, tags, bundle); - const uint64 load_latency_microsecs = [&]() -> uint64 { - const uint64 end_microseconds = Env::Default()->NowMicros(); - // Avoid clock skew. - if (end_microseconds < start_microseconds) return 0; - return end_microseconds - start_microseconds; - }(); auto log_and_count = [&](const string& status_str) { LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") << " }; Status: " << status_str << ". Took " - << load_latency_microsecs << " microseconds."; + << GetLatencyMicroseconds(start_microseconds) << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; if (status.ok()) { @@ -292,7 +318,8 @@ Status LoadSavedModel(const SessionOptions& session_options, } else { log_and_count(kLoadAttemptFail); } - load_latency->GetCell(export_dir)->IncrementBy(load_latency_microsecs); + load_latency->GetCell(export_dir) + ->IncrementBy(GetLatencyMicroseconds(start_microseconds)); return status; } diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index b966c22b2319aef3b87ef54a283911718d37cf84..9549a71c41a0ba2aac58abd8cfb182aa4eaf3b4f 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -28,7 +28,8 @@ from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v1.keras')) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ab1c1be344e2257721507543bc7647d4ff4becb2..d016632da2a9d7c2c2f81c02dd573787a0502923 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; - if (xla::ShapeUtil::Rank(shape) == 0 || + if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; indices = "[0]"; @@ -384,8 +384,9 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->set_profile_counters_size(" - "data->hlo_profile_printer_data()->profile_counters_size());" + ? "set_static_data_profile_counters_size(data, " + "get_static_data_hlo_profile_printer_data(data)->" + "profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -449,7 +450,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { +class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -464,16 +465,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function({{ENTRY}}); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, {{ENTRY}}); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 968afad65ed6d4b5510687df484b7ce6743f6a85..35994fc785d3e1d5e883c49bec96de315e189d2e 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -59,7 +59,7 @@ namespace bar { // arg bytes aligned: 192 // temp bytes total: 126 // temp bytes aligned: 320 -class MyClass : public tensorflow::XlaCompiledCpuFunction { +class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 2; @@ -74,16 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function(entry_point); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, entry_point); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); return data; }(); @@ -256,7 +257,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 64); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index ce8e5ec8c96a2c3696f14b8eea206d648182ecb5..7f7b96428572705f30144e6c95cd4cf9c44ce2a3 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 64b861a73091642b03573543a5c55618bf33915d..7bac79ec062af7e790134286e34eda4e123e138a 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -50,7 +50,7 @@ def tfadd_with_ckpt(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) @@ -65,7 +65,7 @@ def tfadd_with_ckpt_saver(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 4051664c24cacad4a2d151ad3ac9009015900609..2abe3e29b78dbbe719637b13418704acc213d050 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -207,7 +207,7 @@ def tf_library( # # Note that setting the local=1 attribute on a *test target* causes the # test infrastructure to skip that test. However this is a genrule, not - # a test target, and runs with --genrule_strategy=forced_forge, meaning + # a test target, and runs with --strategy=Genrule=forced_forge, meaning # the local=1 attribute is ignored, and the genrule is still run. # # https://www.bazel.io/versions/master/docs/be/general.html#genrule diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index d548de8c44285f6d21dd778db464a31e1b19645b..0b6ab7e723d6e3a55da2f1c30b75f44cbdaa75bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -136,6 +136,10 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); + if (argc > 1 && absl::string_view(argv[1]) == "--help") { + std::cerr << usage << "\n"; + return 0; + } bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b9a87ba296abfc6b9d9aaeff3b3e26678e4e1b94..55e2e6d11f7a0984d2e1a40990c3d3db85bd1ff4 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -175,12 +175,18 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", @@ -634,10 +640,10 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core/grappler/optimizers/data:graph_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 48a23a4c1711ac88a329723c46559112d5a39dbd..390ffa694b6f127544d92f3024a02d877556aacd 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0562838f628c66b1eb03af9d2a5139c01dca31c5..0ef0d3db8c16e4b3f78d29aad5a2ae75a81d96f6 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -20,7 +20,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/hash/hash.h" @@ -222,29 +225,40 @@ class NotPredicate : public Predicate { std::array operands_; }; -// Represents an infinite list of predicates. +// Represents the liveness of an induction variable. For users inside the loop +// this represents the "current" liveness of the induction variable. For users +// outside the loop it represents the "last" liveness of the induction variable. // -// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands -// for the list of predicates: +// More concretely, an and recurrence {S,&,X} represents the liveness of V +// in the following graph: // -// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// V = Merge(S', V_NextIt) +// V = Op(V, X') +// V_NextIt = NextIteration(V) // -// where GenSym(, ) renames every SymbolPredicate in -// by appending to it, in effect creating a "fresh" symbol. -// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on -// subsequent iterations". +// where Predicate(S') = S and Predicate(X') = X. +// +// `X` may contain symbolic predicates and the operations corresponding to these +// symbolic predicates are either in frame `loop` or outside it. The symbols +// that are inside frame `loop` are loop variant (i.e. can have different +// liveness in each loop iteration) and the symbols that are outside frame +// `loop` are loop invariant (i.e. have the same liveness across all +// iterations). class AndRecurrencePredicate : public Predicate { public: - explicit AndRecurrencePredicate(Predicate* start, Predicate* step) - : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), - operands_({start, step}) {} + explicit AndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) + : Predicate(Hash(start, step, frame)), + operands_({start, step}), + frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } + absl::Span frame() const { return frame_; } string ToString() const override { return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + "}<", absl::StrJoin(frame(), ";"), ">"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -255,6 +269,17 @@ class AndRecurrencePredicate : public Predicate { private: std::array operands_; + std::vector frame_; + + static int64 Hash(Predicate* start, Predicate* step, + const std::vector& frame) { + uint64 frame_hash = 0; + for (const string& sub_frame : frame) { + frame_hash = Hash64Combine(Hash64(sub_frame), frame_hash); + } + return Hash64Combine( + HashPredicateSequence(Kind::kAndRecurrence, {start, step}), frame_hash); + } }; // Represents an uninterpreted symbol in a logical predicate. @@ -281,7 +306,7 @@ class SymbolPredicate : public Predicate { // "tensor_id() is live and evaluates to true". // // If `must_be_true()` is false then this SymbolPredicate represents the - // proposition "tensor_id() is live (and may evalutate to any value)" + // proposition "tensor_id() is live (and may evaluate to any value)" TensorId tensor_id() const { return tensor_id_; } bool must_be_true() const { return must_be_true_; } @@ -333,34 +358,58 @@ class PredicateFactory { } Predicate* MakeNotPredicate(Predicate* pred) { - SignatureForNot signature = pred; - auto it = interned_not_instances_.find(signature); - if (it == interned_not_instances_.end()) { - std::unique_ptr new_pred = Make(pred); - Predicate* new_pred_ptr = new_pred.get(); - interned_not_instances_.emplace(signature, std::move(new_pred)); - return new_pred_ptr; - } else { - return it->second.get(); + auto it = make_not_predicate_cache_.find(pred); + if (it != make_not_predicate_cache_.end()) { + return it->second; } + + Predicate* result = MakeNotPredicateImpl(pred); + + bool insert_successful = + make_not_predicate_cache_.insert({pred, result}).second; + (void)insert_successful; + DCHECK(insert_successful); + + return result; } - Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { - auto it = interned_and_rec_instances_.find({start, step}); + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) { + SignatureForAndRec signature(start, step, std::move(frame)); + auto it = interned_and_rec_instances_.find(signature); if (it != interned_and_rec_instances_.end()) { return it->second.get(); } - std::unique_ptr new_pred = - Make(start, step); + std::unique_ptr new_pred = Make( + std::get<0>(signature), std::get<1>(signature), std::get<2>(signature)); Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_rec_instances_ - .emplace(SignatureForAndRec(start, step), std::move(new_pred)) - .second); + bool inserted = + interned_and_rec_instances_.emplace(signature, std::move(new_pred)) + .second; + (void)inserted; + DCHECK(inserted); return new_pred_ptr; } - Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true, + Predicate** predicate) { + TensorId tensor_id(node->name(), output_idx); + + bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL; + TF_RET_CHECK(!must_be_true || is_boolean_tensor); + + if (node->type_string() == "Const" && must_be_true) { + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto)); + + Tensor tensor(proto->dtype()); + TF_RET_CHECK(tensor.FromProto(*proto)); + + *predicate = tensor.scalar()() ? MakeTrue() : MakeFalse(); + return Status::OK(); + } + SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); if (it == interned_symbol_instances_.end()) { @@ -369,16 +418,63 @@ class PredicateFactory { Predicate* new_pred_ptr = new_pred.get(); interned_symbol_instances_.emplace(std::move(signature), std::move(new_pred)); - return new_pred_ptr; + *predicate = new_pred_ptr; } else { - return it->second.get(); + *predicate = it->second.get(); } + + return Status::OK(); } Predicate* MakeTrue() { return MakeAndPredicate({}); } Predicate* MakeFalse() { return MakeOrPredicate({}); } + ~PredicateFactory() { + DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?"; + } + private: + Predicate* MakeNotPredicateImpl(Predicate* pred) { + IncrementStackDepth stack_frame(this); + if (!stack_frame.HasOverflowed()) { + if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) { + return simplified; + } + + // ~~A => A + if (auto* not_pred = dynamic_cast(pred)) { + return not_pred->operand(); + } + } + + SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* SimplifyUsingDeMorgan(Predicate* pred) { + // ~(A & B & C & ...) => ~A | ~B | ~C | ~... + // ~(A | B | C | ...) -> ~A & ~B & ~C & ~... + Predicate::Kind kind = pred->kind(); + + if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) { + std::vector new_operands; + absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands), + [&](Predicate* p) { return MakeNotPredicate(p); }); + return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands) + : MakeOrPredicate(new_operands); + } + + return nullptr; + } + template std::unique_ptr Make(Args&&... args) { return std::unique_ptr( @@ -402,7 +498,8 @@ class PredicateFactory { using SignatureForAndOr = std::pair>; using SignatureForNot = Predicate*; - using SignatureForAndRec = std::pair; + using SignatureForAndRec = + std::tuple>; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -422,6 +519,36 @@ class PredicateFactory { } }; + // Used to limit recursion to avoid blowing up the stack and cap compile time. + class IncrementStackDepth { + public: + explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) { + parent_->stack_depth_++; + } + + bool HasOverflowed() const { + const int kMaxStackDepth = 8; + return parent_->stack_depth_ >= kMaxStackDepth; + } + + ~IncrementStackDepth() { parent_->stack_depth_--; } + + private: + PredicateFactory* parent_; + }; + + // A cache for the MakeNotPredicate function. + // + // NB! This is *not* the same as `interned_not_instances_`. + // `interned_not_instances_` maps ensures pointer identity for `NotPredicate` + // instances, i.e., it ensures there at most one instance of Not(predicate) + // for any given predicate whereas `make_not_predicate_cache_` simply caches + // the result of the `MakeNotPredicate` function. The values in + // `interned_not_instances_` are always instance of `NotPredicate` whereas the + // values in `make_not_predicate_cache_` may not be (for instance it will map + // Not(Not(A)) to A). + absl::flat_hash_map make_not_predicate_cache_; + absl::flat_hash_map, HashSignatureForAndOr> interned_and_or_instances_; @@ -432,6 +559,7 @@ class PredicateFactory { absl::flat_hash_map, HashSignatureForSymbol> interned_symbol_instances_; + int stack_depth_ = 0; }; Predicate* PredicateFactory::MakeInternedAndOr( @@ -466,6 +594,13 @@ Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + + IncrementStackDepth stack_frame(this); + if (stack_frame.HasOverflowed()) { + return MakeInternedAndOr( + std::vector(operands.begin(), operands.end()), pred_kind); + } + Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; absl::flat_hash_set simplified_ops_set; @@ -494,16 +629,31 @@ Predicate* PredicateFactory::MakeAndOrImpl( // Simplify "A&~A=>False" and "A|~A=>True". absl::flat_hash_set negated_ops; - for (Predicate* op : simplified_ops) { - if (op->kind() == Predicate::Kind::kNot) { - negated_ops.insert(dynamic_cast(*op).operand()); - } - } - for (Predicate* op : simplified_ops) { if (negated_ops.count(op)) { + // Simple case: + // + // A & ~A & ... == False + // A | ~A | ... == True return is_and ? MakeFalse() : MakeTrue(); } + + Predicate* negated_op = MakeNotPredicate(op); + if (negated_op->kind() == pred_kind) { + // Slightly more complicated case: + // + // (~A | ~B | ~C) & A & B & C & ... == + // ~(A & B & C) & (A & B & C) & ... == False + // + // (~A & ~B & ~C) | A | B | C | ... == + // ~(A | B | C) | (A | B | C) | ... == True + if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) { + return simplified_ops_set.contains(p); + })) { + return is_and ? MakeFalse() : MakeTrue(); + } + } + negated_ops.insert(negated_op); } // If all ops contain the same subop, then factor it out thanks to the @@ -619,6 +769,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; + std::vector control_flow_info_; bool vlog_; }; @@ -661,9 +812,12 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); - Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( - TensorId(pred_edge->src()->name(), pred_edge->src_output()), - /*must_be_true=*/true); + + Predicate* true_switch; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + pred_edge->src(), pred_edge->src_output(), + /*must_be_true=*/true, &true_switch)); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); // Output 0 is alive iff all inputs are alive and the condition is false. @@ -761,6 +915,23 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; } + +Status GetFullFrame(const Node* n, absl::Span cfi_infos, + std::vector* frame) { + int depth = 0; + for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); + n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { + frame->push_back(cfi_iter->frame_name); + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, @@ -783,8 +954,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, if (has_unvisited_backedge) { // We're visiting this merge for the first time and it has an unvisited // backedge. - Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false); + Predicate* input_data_pred; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -825,8 +998,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, Predicate* start = predicate_factory_.MakeOrPredicate(non_recurrent_inputs); - Predicate* and_rec = - predicate_factory_.MakeAndRecurrencePredicate(start, step); + std::vector frame; + TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame)); + Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( + start, step, std::move(frame)); SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); return Status::OK(); } @@ -841,8 +1016,10 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, // acquire a dead signal from a _Send. std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); - input_preds.push_back(predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false)); + Predicate* signal_is_alive; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive)); + input_preds.push_back(signal_is_alive); SetPredicate(n, {0, Graph::kControlSlot}, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); @@ -892,6 +1069,24 @@ Status DeadnessAnalysisImpl::Populate() { Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::Span rpo) { + std::vector unreachable_nodes; + // Compute the loop structure of the graph. + TF_RETURN_IF_ERROR( + BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes)); + + // Do some opportunistic error checking: + if (!unreachable_nodes.empty()) { + if (unreachable_nodes.size() > 5) { + unreachable_nodes.erase(unreachable_nodes.begin() + 5, + unreachable_nodes.end()); + } + + return errors::InvalidArgument( + "Found unreachable nodes, most likely source and sink nodes not " + "connected: ", + absl::StrJoin(unreachable_nodes, ", ")); + } + // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 8a73101c184e6190921fd7729742922bd96f4bcf..16ee8f86d55c72785368ac2fd67635eba2fa7cd7 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -123,10 +123,9 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = - ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); - Output loop_cond = - ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); - ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, + loop_cond_expr); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), @@ -140,7 +139,7 @@ InductionVarInfo CreateInductionVariable(const Scope& root, root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return {iv.output, loop_cond}; + return {iv.output, loop_cond_expr}; } InductionVarInfo CreateInductionVariable(const Scope& root, @@ -515,24 +514,27 @@ TEST(DeadnessAnalysisTest, Loop) { // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 // produce the same deadness. But we're not that smart today. - EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], + "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], + "{#true,&,*iv2/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); EXPECT_EQ(predicate_map[ControlOutputFor(add1)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); } } TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); Output dependent_iv0 = - CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0) .induction_var; Output dependent_iv1 = - CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0) .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); @@ -549,13 +551,13 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], - "{#true,&,*iv0/cond:0}"); + "{#true,&,*iv0/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); } } @@ -595,32 +597,33 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_outer = - CreateInductionVariable(root, "iv_outer", "frame", 0); + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); ops::Switch inner_value(root.WithOpName("outer_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer.loop_cond); + enter_constant_outer_loop, iv_outer.loop_cond); InductionVarInfo iv_inner = CreateInductionVariable( - root, "iv_inner", "frame", - ops::internal::Enter(root.WithOpName("iv_inner/enter"), - inner_value.output_true, "frame_inner")); + root, "iv_inner", "inner_loop", inner_value.output_true); Output dependent_outer_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; Output dependent_outer_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; - Output dependent_inner_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame", - iv_inner.loop_cond, dependent_outer_iv0) - .induction_var; - Output dependent_inner_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", - iv_inner.loop_cond, dependent_outer_iv1) - .induction_var; + Output dependent_inner_iv0 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv0", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv1", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, dependent_inner_iv1); @@ -638,46 +641,50 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], - "{#true,&,*iv_outer/cond:0}"); + "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], - "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," - "*iv_inner/cond:0}"); + "{({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); } } TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv_outer_0 = - CreateInductionVariable(root, "iv_outer_0", "frame", 0); - ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_0.loop_cond); - InductionVarInfo iv_inner_0 = CreateInductionVariable( - root, "iv_inner_0", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), - inner_value_0.output_true, "frame_inner")); - - InductionVarInfo iv_outer_1 = - CreateInductionVariable(root, "iv_outer_1", "frame", 1); - ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_1.loop_cond); - InductionVarInfo iv_inner_1 = CreateInductionVariable( - root, "iv_inner_1", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), - inner_init_value_1.output_true, "frame_inner")); - Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, - iv_inner_1.induction_var); + + std::array outer_iv; + std::array inner_iv; + + for (int i : {0, 1}) { + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + outer_iv[i] = iv_outer.induction_var; + inner_iv[i] = iv_inner.induction_var; + } + + Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]); VLogGraphIfAsked(*root.graph()); @@ -692,21 +699,76 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], - "{#true,&,*iv_outer_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], - "{#true,&,*iv_outer_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], - "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0} & " - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])], + "{({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])], + "{#true,&,*iv_outer/cond_1:0}"); + EXPECT_EQ( + predicate_map[ControlOutputFor(inner_iv[1])], + "{({#true,&,*iv_outer/cond_1:0} & " + "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0}"); + EXPECT_EQ( + predicate_map[ControlOutputFor(add0)], + "({({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0} & " + "{({#true,&,*iv_outer/cond_1:0} & " + "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0})"); + } +} + +TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); + InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9); + + Output init = CreateSwitch(root, "init").output_true; + Output step = CreateSwitch(root, "step").output_true; + + std::array exits; + std::array next_iterations; + + for (int i : {0, 1}) { + Output init_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("init_enter_frame_", i)), init, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + Output step_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("step_enter_frame_", i)), step, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + + ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)), + {init_enter, init_enter}); + Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output, + step_enter); + next_iterations[i] = ops::NextIteration( + root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add); + EXPECT_TRUE( + root.graph() + ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1) + .ok()); + exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)), + iv.output); + } + + FixupSourceAndSinkEdges(root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], + predicate_map[ControlOutputFor(exits[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], ""); + + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], + predicate_map[ControlOutputFor(next_iterations[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], ""); } } @@ -818,5 +880,82 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); } +TEST(DeadnessAnalysisTest, DeMorgan) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL); + Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0); + ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1); + + Output and_0_1 = + ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true); + + Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"), + {sw_0.output_false, sw_1.output_false}) + .output; + + // Predicate(should_always_be_dead) = + // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False + Output should_always_be_dead = + ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1); + + // Predicate(should_always_be_dead) = + // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True + Output should_always_be_alive = + ops::Merge(root.WithOpName("should_always_be_alive"), + {and_0_1, or_not0_not1}) + .output; + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_true = ops::Const(root.WithOpName("const_true"), true); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_true); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_false = ops::Const(root.WithOpName("const_false"), false); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_false); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#true"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 03aba97bbe81a11f6366d118ee5bc573d0c6b31b..d0d7a3f3785469acd79a83b6897668f94fc6ea2e 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1008,13 +1008,15 @@ Status Encapsulator::Subgraph::AddHostComputes( // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, host_compute); + graph_->AddControlEdge(src_image, host_compute, + /* allow_duplicates= */ true); } // Connect the _HostCompute node to its ancestor host compute nodes. for (const auto& ancestor_name : host_compute_ancestors) { Node* ancestor = host_compute_node[ancestor_name]; - graph_->AddControlEdge(ancestor, host_compute); + graph_->AddControlEdge(ancestor, host_compute, + /* allow_duplicates= */ true); } // Connect the consumers in the subgraph to the _HostCompute node. @@ -1031,7 +1033,8 @@ Status Encapsulator::Subgraph::AddHostComputes( // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(host_compute, dst_image); + graph_->AddControlEdge(host_compute, dst_image, + /* allow_duplicates= */ true); } } } @@ -1059,7 +1062,8 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_); + graph_out->AddControlEdge(sequencer_, call_node_, + /* allow_duplicates= */ true); } } @@ -1279,7 +1283,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( // completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_, + true /* skip duplicates check */); return Status::OK(); } @@ -1336,7 +1341,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( // subgraph completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_, + /* allow_duplicates= */ true); return Status::OK(); } @@ -1446,7 +1452,8 @@ Status Encapsulator::CopySubgraphEdges( src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); + g->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } else { g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); } @@ -1732,7 +1739,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edges_added ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { - graph_out->AddControlEdge(src_image, dst_image); + graph_out->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } return Status::OK(); @@ -1761,7 +1769,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), - subgraphs_[subgraph].GetCallNode()); + subgraphs_[subgraph].GetCallNode(), + /* allow_duplicates= */ true); } } return Status::OK(); @@ -2129,7 +2138,8 @@ Status CheckClusterDependencyForCycles( const string& ancestor, const string& successor, const std::unordered_map>& ancestors, const std::unordered_map& node_ancestors_map, - GraphCycles* cycle_detector, std::map* cycle_detector_map) { + GraphCycles* cycle_detector, + std::unordered_map* cycle_detector_map) { if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); } @@ -2173,7 +2183,7 @@ Status Encapsulator::FindClusterDependencies() { // We check that clusters are acyclic using this cycle detector. GraphCycles cycle_detector; // Map from cluster name to cycle detector node id. - std::map cycle_detector_map; + std::unordered_map cycle_detector_map; // Process the nodes in topologically-sorted order. std::vector nodes; GetReversePostOrder(*graph_in_, &nodes); @@ -2535,7 +2545,33 @@ Status EncapsulateSubgraphsPass::Run( std::vector* input_permutation, std::vector* output_permutation, NodeDef* node) { // Optimize the subgraph. - OptimizeGraph(flr, subgraph); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. + auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; + OptimizeGraph(flr, subgraph, graph_optimizer_options); const int num_args = input_permutation->size(); std::vector const_args(num_args); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 8617beec004d0fe912155f054442c5b6249bb6b5..1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,6 +34,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -513,6 +517,18 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; + // Create FunctionLibraryRuntime. + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + OptimizerOptions opts; + auto device_mgr = absl::make_unique(std::move(devices)); + auto pflr = absl::make_unique( + device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions( "_encapsulate", /*outside_compilation_attribute=*/"", *graph, @@ -538,7 +554,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, std::map{}}); } s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, - graph_out.get(), lib_def.get()); + graph_out.get(), flr, lib_def.get()); if (!s.ok()) return s; GraphDef graphdef_out; @@ -941,7 +957,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1101,7 +1119,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", @@ -1112,7 +1132,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, @@ -1244,7 +1266,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1269,7 +1293,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); @@ -1397,7 +1423,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1419,7 +1447,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"i_0_retval_retval", "I:o:0"}}); @@ -1527,7 +1557,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1615,7 +1647,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1716,7 +1750,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -1821,7 +1857,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -1949,7 +1987,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1959,7 +1999,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2082,7 +2124,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2092,7 +2136,9 @@ TEST(EncapsulateSubgraphsTest, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2214,7 +2260,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2224,7 +2272,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", @@ -2235,7 +2285,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O3"}}, + {"_outside_compilation_subgraph", "O3"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2354,7 +2406,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"f_0_retval_retval", "F:o:0"}}); @@ -2465,7 +2519,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 8b01768c49422b331b52a8ba31bade000c95722e..2a770c527b2fae91352fd17dacb13495a3a73f34 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { @@ -308,6 +309,10 @@ xla::StatusOr BuildXlaHostComputeNodeDef( host_compute_builder.Attr("tpu_core", core); } + // Set input tokens. + host_compute_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + // Populate inputs. std::vector input_dtypes; TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); @@ -398,8 +403,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( } // Resets "device_ordinal" attr to placeholder value for related nodes -// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If nodes containing -// XlaRecvAtHost/XlaSendFromHost). +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes +// containing XlaRecvAtHost/XlaSendFromHost). Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("device_ordinal"); @@ -429,6 +434,10 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr(attr_name); n->AddAttr(attr_name, branch_func); } + } else if (HasNodeAttr(n->def(), "device_ordinal")) { + // Function call node containing outside compilation. + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); } else { return errors::Internal("Unknown node marked with ", kXlaHasHostTransferAttrName, ": ", @@ -1217,20 +1226,129 @@ Status BuildHostGraphForWhileNode( return Status::OK(); } +// Builds host graph for func call nodes. +Status BuildHostGraphForFuncCallNode(const string& func_call_node_name, + const string& xla_cluster_name, + const string& func_call_host_func_name, + const string& host_graph_func_name, + FunctionLibraryDefinition* fld) { + Graph host_graph(fld); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + + // Step 1: add key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); + + // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg + // node. + TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( + xla_cluster_name, func_call_host_func_name, fld)); + + // Step 3: build a function call node with `host_func_name`, with + // `key_placeholder` as input. + NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name), + func_call_host_func_name, fld); + call_builder.Input(key_placeholder->name(), 0, DT_STRING); + call_builder.Attr("device_ordinal", device_ordinal_value); + call_builder.Attr(kXlaHasHostTransferAttrName, true); + NodeDef call_def; + TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); + Status s; + Node* call_node = host_graph.AddNode(call_def, &s); + TF_RETURN_IF_ERROR(s); + host_graph.AddEdge(key_placeholder, 0, call_node, 0); + + // Convert `host_graph` to function, and add a "device_ordinal" attr. + FunctionDef oc_host_graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name, + &oc_host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef)); + } + + return Status::OK(); +} + Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( Graph* g, const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* host_graphs, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - std::vector if_nodes, while_nodes; + std::vector if_nodes, while_nodes, func_call_nodes; for (Node* n : g->nodes()) { if (n->type_string() == "If") { if_nodes.push_back(n); } else if (n->type_string() == "While") { while_nodes.push_back(n); + } else if (fld->Contains(n->type_string())) { + func_call_nodes.push_back(n); + } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { + // Only gradient for user-defined function should be considered as + // function call node. + NameAttrList original_func; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func)); + if (fld->Contains(original_func.name())) { + func_call_nodes.push_back(n); + } + } + } + + for (Node* n : func_call_nodes) { + // Extract outside compilation for the function call. + bool func_has_outside_compilation = false; + NameAttrList func; + func.set_name(n->type_string()); + typedef protobuf::Map AttrMap; + *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); + string new_func_name = absl::StrCat(n->name(), "_oc"); + string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func, new_func_name, host_func_name, host_compute_core, flr, fld, + shape_inference_graphs, &func_has_outside_compilation)); + + // If the function call does not have outside compilation, nothing to do. + if (!func_has_outside_compilation) { + continue; } + + *has_outside_compilation = true; + + // Change `n` to call the new function directly. + NodeDefBuilder replace_builder(n->name(), new_func_name, fld); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + replace_builder.Input(e->src()->name(), e->src_output(), + e->src()->output_type(e->src_output())); + } + for (const auto& attr : n->attrs()) { + replace_builder.Attr(attr.first, attr.second); + } + NodeDef replace_def; + TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def)); + TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def)); + replace->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the function call. + string oc_host_graph_name = + absl::StrCat("oc_func_host_graph_", replace->name()); + TF_RETURN_IF_ERROR( + BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name, + host_func_name, oc_host_graph_name, fld)); + + // Record the host graph. + host_graphs->push_back(oc_host_graph_name); } for (Node* n : if_nodes) { @@ -1251,12 +1369,12 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, then_branch, then_branch_xla_func_name, then_branch_host_func_name, - host_compute_core, fld, shape_inference_graphs, + host_compute_core, flr, fld, shape_inference_graphs, &then_branch_has_outside_compilation)); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, else_branch, else_branch_xla_func_name, else_branch_host_func_name, - host_compute_core, fld, shape_inference_graphs, + host_compute_core, flr, fld, shape_inference_graphs, &else_branch_has_outside_compilation)); // If then/else branch do not have outside compilation, nothing to do. @@ -1316,12 +1434,12 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( body_xla_func_name = absl::StrCat(body.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - cond, cond_xla_func_name, cond_host_func_name, host_compute_core, fld, - shape_inference_graphs, &cond_has_outside_compilation)); + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &cond_has_outside_compilation)); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - body, body_xla_func_name, body_host_func_name, host_compute_core, fld, - shape_inference_graphs, &body_has_outside_compilation)); + body, body_xla_func_name, body_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &body_has_outside_compilation)); // If cond/body do not have outside compilation, nothing to do. if (!cond_has_outside_compilation && !body_has_outside_compilation) { @@ -1469,17 +1587,27 @@ Status ExtractOutsideCompilationForFunction( const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, const string& host_graph_func_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { + // Convert the function to graph. const string& func_name = func_name_attrs.name(); - const FunctionDef* fdef = fld->Find(func_name); - if (!fdef) { - return errors::Internal("Cannot find function ", func_name); - } + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + + // Check if we have outside compilation nodes. *has_outside_compilation = false; - for (auto& node_def : fdef->node_def()) { - if (HasNodeAttr(node_def, outside_compilation_attr_name)) { + for (Node* n : fbody->graph->nodes()) { + if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { *has_outside_compilation = true; break; } @@ -1487,16 +1615,6 @@ Status ExtractOutsideCompilationForFunction( // We cannot early return here, because we might have outside compilation in // If/While function body. - // Convert the function to graph. - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(func_name), AttrSlice(&func_name_attrs.attr()), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - // Preprocess edges between different outside compilations. They will be // restored in `ConstructHostGraph()`. TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( @@ -1553,16 +1671,11 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_for_func_after_", func_name), - *graph_out, fld); - } // Handle nodes with associated functions. TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, - xla_cluster_name, host_compute_core, fld, + xla_cluster_name, host_compute_core, flr, fld, &outside_compilation_host_graphs, shape_inference_graphs, has_outside_compilation)); @@ -1580,20 +1693,31 @@ Status ExtractOutsideCompilationForFunction( FunctionDef updated_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + const FunctionDef* original_fdef = fld->Find(func_name); + if (original_fdef) { + for (const auto& attr : original_fdef->attr()) { + (*updated_fdef.mutable_attr())[attr.first] = attr.second; + } + } if (fld->Find(new_func_name)) { TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_after_", func_name), + *graph_out, fld); + } - return Status::OK(); + return ret_status; } Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld) { + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); } @@ -1610,7 +1734,7 @@ Status ExtractOutsideCompilation( TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, func_name_attrs, func_name_attrs.name(), host_graph_func_name, - host_compute_core, fld, &shape_inference_graphs, + host_compute_core, flr, fld, &shape_inference_graphs, &has_outside_compilation)); TF_RETURN_IF_ERROR( ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index e07e7c5dd0cd42ddd4d643d8b36583c82056bbb2..d64cc2a103ed040cbf413ac736f97f84459e869b 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -89,7 +89,7 @@ Status ExtractOutsideCompilationForFunction( const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, const string& host_graph_func_name, - const std::map& host_compute_core, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation); @@ -101,7 +101,7 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld); + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index e9a89e34e0c7b04b4be34e367b2d0bf627c0061a..7c3a24feff81b21a5d2347d21fb80988bc3e6065 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -31,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -222,7 +225,42 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) { EXPECT_EQ(shapes[0].dim_size(), 1); } -TEST(ExtractOutsideCompilationForFunctionTest, Basic) { +class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { + public: + void SetUp() override { + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + } + + Status ExtractOutsideCompilationTest( + const string &xla_cluster_attr_name, + const string &outside_compilation_attr_name, + const string &xla_cluster_name, const NameAttrList &func_name_attrs, + const string &new_func_name, const string &host_graph_func_name, + const std::map &host_compute_core, + FunctionLibraryDefinition *fld, + std::vector *shape_inference_graphs, + bool *has_outside_compilation) { + OptimizerOptions opts; + pflr_ = absl::make_unique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, + /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + return ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func_name_attrs, new_func_name, host_graph_func_name, host_compute_core, + flr, fld, shape_inference_graphs, has_outside_compilation); + } + + private: + std::unique_ptr device_mgr_; + std::unique_ptr pflr_; +}; + +TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { // Build the XLA computation func. // "const0" // "identity0" = "const0" (outside compilation cluster "0") @@ -256,7 +294,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -362,7 +400,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { } } -TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { +TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { // Build the XLA computation func. // "const0" FunctionDefLibrary fdl; @@ -384,7 +422,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -406,7 +444,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { EXPECT_EQ(host_graph->num_nodes(), 2); } -TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { +TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. // "const0" // "const1" (outside compilation cluster "0") @@ -432,7 +470,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -489,7 +527,7 @@ REGISTER_OP("XlaRecvFromHost") .Attr("key: string") .SetIsStateful(); -TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // Build the XLA computation func. // "const0" (bool) // "const1" (int32) @@ -555,7 +593,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -651,7 +689,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { } } -TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { // Build the XLA computation func. // "const0" (bool) // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") @@ -714,7 +752,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( + TF_CHECK_OK(ExtractOutsideCompilationTest( "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); @@ -782,4 +820,162 @@ TEST(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { } } +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { + // Build the XLA computation func. + // "const0" (int32) + // "fn" (input = "const0") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "fn", true_fn_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + { + std::unique_ptr g(new Graph(&fld)); + + tensorflow::TensorProto tensor_proto; + tensor_proto.set_dtype(tensorflow::DT_INT32); + tensorflow::TensorShapeProto shape; + shape.add_dim()->set_size(2); + *tensor_proto.mutable_tensor_shape() = shape; + for (int i = 0; i < 2; ++i) { + tensor_proto.add_int_val(1); + } + NodeDef const_def; + TF_CHECK_OK(NodeDefBuilder("const", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor_proto) + .Finalize(&const_def)); + Status s; + Node *const_node = g->AddNode(const_def, &s); + TF_CHECK_OK(s); + + NodeDef fn_def; + TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld) + .Input("const", 0, DT_INT32) + .Finalize(&fn_def)); + Node *fn_node = g->AddNode(fn_def, &s); + TF_CHECK_OK(s); + g->AddEdge(const_node, 0, fn_node, 0); + + NodeDef ret_def; + TF_CHECK_OK(NodeDefBuilder("ret", "_Retval") + .Attr("index", 0) + .Attr("T", DT_INT32) + .Input("fn", 0, DT_INT32) + .Finalize(&ret_def)); + Node *ret_node = g->AddNode(ret_def, &s); + TF_CHECK_OK(s); + g->AddEdge(fn_node, 0, ret_node, 0); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); + } + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have call node for outside compilation in `fn`. + Node *call_node = node_name_index["oc_call_fn"]; + EXPECT_NE(call_node, nullptr); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes. + bool has_recv = false, has_send = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "_XlaRecvAtHost") { + has_recv = true; + } else if (n->type_string() == "_XlaSendFromHost") { + has_send = true; + } + } + EXPECT_TRUE(has_recv); + EXPECT_TRUE(has_send); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have call node. + Node *fn_node = node_name_index["fn"]; + EXPECT_NE(fn_node, nullptr); + EXPECT_EQ(fn_node->type_string(), "fn_oc"); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("fn_oc"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have XlaHostCompute nodes. + bool has_hc = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "XlaHostCompute") { + has_hc = true; + } + } + EXPECT_TRUE(has_hc); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 98e344b3a080aa8aab27cd41564a90427bac151e..fba69dfccc31e01e73d8f86006b41ce5e3283f15 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -68,7 +68,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_fusion_only", &mark_for_compilation_flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}; + "global_jit_level is ON*."), + Flag("tf_xla_disable_deadness_safety_checks_for_debugging", + &mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging, + "Disable deadness related safety checks when clustering (this is " + "unsound).")}; flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); } @@ -89,6 +94,8 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_fusion_only = false; + mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging = false; device_flags = new XlaDeviceFlags; device_flags->tf_xla_compile_on_demand = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 5ddea588eef5270880d91623dc05893da265960a..ed7810fcfd85c17db70d42e691446b60dc696939 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -25,27 +25,39 @@ namespace tensorflow { // Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { - int32 tf_xla_auto_jit; // Control compilation of operators into XLA - // computations on CPU and GPU devices. 0 = use - // ConfigProto setting; -1 = off; 1 = on for things - // very likely to be improved; 2 = on for everything. - // Experimental. - int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA - // compilation. Ignored for operators placed - // on an XLA device or operators explicitly - // marked for compilation. - int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA - // compilation. - bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. - bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU - // via SessionOptions. - int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this - // many ops will be marked as eligible for - // clustering. - bool tf_xla_fusion_only; // This flag is effective only when global_jit_level - // is set to ON* and overrides its behavior. If - // true, enable fusion of element-wise operations - // only using XLA. + // Control compilation of operators into XLA computations on CPU and GPU + // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very + // likely to be improved; 2 = on for everything. + // + // Experimental. + int32 tf_xla_auto_jit; + + // Minimum number of operators in an XLA compilation. Ignored for operators + // placed on an XLA device or operators explicitly marked for compilation. + int32 tf_xla_min_cluster_size; + + // Maximum number of operators in an XLA compilation. + int32 tf_xla_max_cluster_size; + + // Dump graphs during XLA compilation. + bool tf_xla_clustering_debug; + + // Enables global JIT compilation for CPU via SessionOptions. + bool tf_xla_cpu_global_jit; + + // "Compiler fuel" for clustering. Only this many ops will be marked as + // eligible for clustering. + int64 tf_xla_clustering_fuel; + + // tf_xla_fusion_only is effective only when global_jit_level is set to ON* + // and overrides its behavior. If true, enable fusion of element-wise + // operations only using XLA. + bool tf_xla_fusion_only; + + // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then + // we do not do deadness related safety checks. This is unsound in general, + // but can be used as a debugging aid. + bool tf_xla_disable_deadness_safety_checks_for_debugging; }; // Flags associated with the XLA bridge's xla_device module. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 6618e3a58ab7b6374ed775cd6e4e18a6a4975588..20c2cd7e0561f92a01486102c4d2c572fd80c957 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -41,7 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -677,12 +678,28 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. + FixupSourceAndSinkEdges(options.graph->get()); + std::unique_ptr deadness; { XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); } + bool deadness_analysis_disabled = + GetMarkForCompilationPassFlags() + ->tf_xla_disable_deadness_safety_checks_for_debugging; + + if (deadness_analysis_disabled) { + LOG(WARNING) << "Deadness analysis was manually disabled via " + "--tf_xla_disable_deadness_safety_checks_for_debugging; " + "auto-clustering " + "is unsound!"; + } + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -715,9 +732,12 @@ Status MarkForCompilationPass::Run( // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. - if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; + if (!deadness_analysis_disabled) { + if (node->IsMerge() || + deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + return false; + } } // Check for fusable ops only if requested. @@ -1145,6 +1165,27 @@ Status MarkForCompilationPass::RunImpl( if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); + + // We also dump out an annoated version of the TF graph where the nodes + // names are prefixed with the cluster names. This can help visualizing the + // clustering decisions on TensorBoard. + Graph new_graph((*options.graph)->op_registry()); + CopyGraph(**options.graph, &new_graph); + + for (Node* n : new_graph.nodes()) { + if (absl::optional cluster_name = + GetXlaClusterForNode(*n)) { + n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else { + // There is room for improvement here. In particular, it may help to + // split these unclustered nodes into classes where every node in a + // specific class has edges to and from the same set of clusters. + n->set_name(absl::StrCat("unclustered/", n->name())); + } + } + + dump_graph::DumpGraphToFile("mark_for_compilation_annotated", new_graph, + options.flib_def); } VLogClusteringSummary(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index bf2c5508ea9e987e80093f4c2e15d3ff5191126f..c2b6250f738fafa35b2c5f79e97cf1281b50a316 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, Complex128Unsupported) { +TEST(XlaCompilationTest, StringUnsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX128) - .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); - Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); - ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + .WithAttr("dtype", DT_STRING) + .WithAttr("value", Tensor(DT_STRING, TensorShape()))); + Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B")); + ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 38a54cc5efae35ad77b6dc8039c653e920cfc071..1d81a8f4fcbf050663626b1f7660afd71f4027bc 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index fef28fc810cb4e544fe3f271f0b96cebd8a96779..80993861abba050fa3d6a133023d3c99f41f73e3 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3df5479a55e841380ca7b8cdd0add9fd17487091..611515cf33bc1abe21e06eb7f1513800276e095b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -38,6 +39,8 @@ limitations under the License. namespace tensorflow { +constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; + XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} @@ -60,7 +63,7 @@ XlaCompilationCache::~XlaCompilationCache() { // about? } -string XlaCompilationCache::DebugString() { +string XlaCompilationCache::DebugString() const { return "XLA JIT compilation cache"; } @@ -68,9 +71,9 @@ string XlaCompilationCache::DebugString() { // arguments in the supplied list. string XlaCompilationCache::Signature::HumanString() const { string result = name; - for (const auto& a : arg_types) { - absl::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + for (const auto& a : arg_shapes) { + absl::StrAppend(&result, ",", DataTypeString(a.first)); + absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]"); } for (const auto& v : arg_values) { @@ -81,7 +84,7 @@ string XlaCompilationCache::Signature::HumanString() const { bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (name != other.name) return false; - if (arg_types != other.arg_types) return false; + if (arg_shapes != other.arg_shapes) return false; if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { @@ -97,10 +100,10 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { uint64 XlaCompilationCache::Signature::Hash::operator()( const XlaCompilationCache::Signature& signature) const { uint64 h = std::hash()(signature.name); - for (const auto& arg : signature.arg_types) { + for (const auto& arg : signature.arg_shapes) { h = Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, std::hash()(arg.second.dims())); - for (int dim : arg.second.dim_sizes()) { + h = Hash64Combine(h, std::hash()(arg.second.size())); + for (int dim : arg.second) { h = Hash64Combine(h, std::hash()(dim)); } } @@ -124,7 +127,7 @@ XlaCompilationCache::BuildSignature( break; case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kResource: - signature.arg_types.emplace_back(arg.type, arg.shape); + signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes()); break; default: return errors::InvalidArgument( diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 846d0c963dbfdf55f51120f2f138d12f5f63839b..7748b4700f39da4f952278ca6c6d2cadff4d3fb8 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -88,14 +88,16 @@ class XlaCompilationCache : public ResourceBase { xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } - string DebugString() override; + string DebugString() const override; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { string name; - std::vector> arg_types; + // List of Tensor types & shapes for compile-time constant arguments to the + // compilation, ordered by argument number. + std::vector>> arg_shapes; // List of Tensor values for compile-time constant arguments to the // compilation, ordered by argument number. Tensors must be in host memory. diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e9770647e7ba96cc1db026d12d5f11f52ce98d35..94dc61d55fb047c0ea81d98fde24cb55387c27d7 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -83,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { +constexpr std::array kAllXlaCpuTypes = { {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 77cd2f44628677942da9e576070d1d295194cead..e2397f6fcb8677f4bd5151646f9ebacd3e23af5b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -219,9 +219,6 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } if (device_context_) { device_context_->Unref(); } @@ -398,12 +395,6 @@ Status XlaDevice::Sync() { if (!stream) return Status::OK(); Status status = stream->BlockHostUntilDone(); - { - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } - } TF_RETURN_IF_ERROR(status); if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); @@ -412,6 +403,8 @@ Status XlaDevice::Sync() { return Status::OK(); } +// TODO(b/112409994): This is no longer necessary. Consolidate it with the +// synchronous version. void XlaDevice::Sync(const DoneCallback& done) { VLOG(1) << "XlaDevice::Sync (asynchronous)"; std::shared_ptr stream; @@ -424,14 +417,20 @@ void XlaDevice::Sync(const DoneCallback& done) { return; } + // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at + // the end of the stream, after everything that has already been enqueued + // there at this moment. When the host callback is called, everything before + // it must have already finished, and the host callback will then place the + // task below onto a background thread. (See the implementation of + // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done + // callback is finally called from that background thread, we know for sure + // that everything enqueued onto the stream (i.e., the device) at this very + // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. + // This achieves a device-wide sync. stream->ThenEnqueueOnBackgroundThread( [this, stream, done](se::StreamExecutor*) { tracing::ScopedActivity activity("XlaDevice::Sync::Callback", /*is_expensive=*/true); - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } done(stream->ok() ? Status::OK() : errors::Internal("XlaDevice::Sync() failed.")); }); @@ -470,57 +469,26 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } -void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { +void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) { mutex_lock lock(mu_); sync_on_completion_ = sync_on_completion; } -bool XlaDevice::RequiresSyncOnCompletion() const { +bool XlaDevice::AllowsSyncOnCompletion() const { mutex_lock lock(mu_); return sync_on_completion_; } -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice* device) - : device_(device) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; -} - -XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { - if (device_) { - mutex_lock lock(device_->mu_); - --device_->outstanding_asynchronous_operations_; - device_->outstanding_asynchronous_operations_cv_.notify_all(); +Status XlaDevice::CurrentStatus() { + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; } -} - -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - const XlaDevice::AsynchronousOperationHandle& other) - : device_(other.device_) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; -} - -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice::AsynchronousOperationHandle&& other) - : device_(other.device_) { - other.device_ = nullptr; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(const XlaDevice::AsynchronousOperationHandle& other) { - device_ = other.device_; - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; - return *this; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(XlaDevice::AsynchronousOperationHandle&& other) { - device_ = other.device_; - other.device_ = nullptr; - return *this; + if (!stream) { + return Status::OK(); + } + return stream->ok() ? Status::OK() : errors::Internal("XlaDevice is not OK."); } XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 45f18ac9ee6d403c192bd421d7823f2d408d994b..e35a1c7d29514dc5777bdbd3858c56401d7b9044 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -167,35 +167,14 @@ class XlaDevice : public LocalDevice { Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to return 'sync_on_completion' for - // RequiresSyncOnCompletion(). - void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + // AllowsSyncOnCompletion(). + void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); - bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); - // A simple RAII handle. On construction the device's - // outstanding_asynchronous_operations_ field is incremented; on destruction - // it is decremented. - class AsynchronousOperationHandle { - public: - AsynchronousOperationHandle(XlaDevice* device); - ~AsynchronousOperationHandle(); - AsynchronousOperationHandle(const AsynchronousOperationHandle& other); - AsynchronousOperationHandle(AsynchronousOperationHandle&& other); - AsynchronousOperationHandle& operator=( - const AsynchronousOperationHandle& other); - AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); - - private: - XlaDevice* device_ = nullptr; - }; - - AsynchronousOperationHandle CreateAsynchronousOperationHandle() { - return AsynchronousOperationHandle(this); - } + Status CurrentStatus() override LOCKS_EXCLUDED(mu_); private: - friend class AsynchronousOperationHandle; - xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -255,14 +234,9 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr thread_pool_; - // True if the device requires XlaDevice::Sync to be called on completion + // True if the device allows XlaDevice::Sync to be called on completion // regardless of status. - bool sync_on_completion_ GUARDED_BY(mu_) = false; - - // Count of outstanding asynchronous operations which must be zero on Sync() - // completion. - int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; - condition_variable outstanding_asynchronous_operations_cv_; + bool sync_on_completion_ GUARDED_BY(mu_) = true; // Set of devices to use. This controls which of the devices on the given // platform will have resources allocated. For GPUs this will be diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 1f3afe8822d441a5ce37617fe18d7767e9bc72e4..28681bb8b03dbf97e8145972f9a04b5855fafdae 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -131,7 +131,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla::ShapeUtil::MakeShape(shape.element_type(), xla::AsInt64Slice(shape.dimensions()))); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( @@ -214,7 +214,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, [ref, xla_tensor, done](xla::Status status) { done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " + VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); return status; }()); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4007309ed1c57b663dca5bac0df11260bf1327f3..e1a582406153d2af447fa9d4ebcaf0bf0842b132 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -26,9 +26,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { +constexpr std::array kExecAllTypes = { {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_BOOL, DT_BFLOAT16}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3b0bda4caa161a7561a3098b89420329998ff8a7..c64981053fad2dbf1e8bcd623a940ded8b4d9150 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -237,7 +237,7 @@ void XlaComputationLaunchContext::PopulateInputs( const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); - if (xla::ShapeUtil::IsTuple(on_device_shape)) { + if (on_device_shape.IsTuple()) { const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); @@ -274,7 +274,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // If the on-host-shape isn't a tuple, create a new single-element tuple // buffer with a nullptr root index table. This allows the code below to treat // output as a tuple unconditionally. - if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + if (!output.on_host_shape().IsTuple()) { ShapedBuffer nontuple_buffer = output.release(); ShapedBuffer buffer( xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), @@ -377,7 +377,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); + VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString(); } } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index fa02cf9cbef45188a6dc2f861ff036649ea92b03..2b9f5d8dbd5152c74936ca92b1066760c4caa00f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -230,6 +230,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -406,7 +407,7 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "large", + size = "medium", srcs = ["eager_test.py"], deps = [ ":xla_test", @@ -677,6 +678,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -826,6 +828,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", ], ) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9a5423c1b2a5df7880453cbb328f6a8174066255..c829c50b5518b29c96c0b0117a6cd143911bd1fc 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -311,6 +311,30 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + x = np.array([ + -0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0, + 1.0 + ], + dtype=dtype) + y = np.array( + [-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0], + dtype=dtype) + expected = np.nextafter(x, y) + + # We use assertAllEqual to expose any bugs hidden by relative or + # absolute error tolerances. + def NextAfterEqualityTest(result, expected, rtol): + del rtol + return self.assertAllEqual(result, expected) + + self._testBinary( + math_ops.nextafter, + x, + y, + expected=expected, + equality_test=NextAfterEqualityTest) + # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( @@ -400,7 +424,7 @@ class BinaryOpsTest(xla_test.XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._testBinary( math_ops.complex, np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 447a7de2cb6526a5dcf7789d4f2bffb5e733e8c0..ed580f95b6c2f57dfdf46cfcd64cabb452980c5d 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -5,6 +5,7 @@ load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", + "tf_exec_compatible_with", ) def all_backends(): @@ -64,7 +65,7 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] elif backend == "gpu": backend_args += [ @@ -84,6 +85,7 @@ def tf_xla_py_test( else: fail("Unknown backend {}".format(backend)) + test_tags = tags + backend_tags native.py_test( name = test_name, srcs = srcs, @@ -92,7 +94,8 @@ def tf_xla_py_test( main = "{}.py".format(name) if main == None else main, data = data + backend_data, deps = deps + backend_deps, - tags = tags + backend_tags, + tags = test_tags, + exec_compatible_with = tf_exec_compatible_with({"tags": test_tags}), **kwargs ) test_names.append(test_name) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index bf5ea7b1fb6fb3c774c4db20d059f131990d20d3..b7d08df9f7d144b71fd0b09535e10b8f596ea6ca 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 2af32b537ba53723370faf81aebf308a465718c7..c9fce39f6c5111f93a54708b59b4c42c3ba844b6 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops @@ -463,7 +465,7 @@ class EagerFunctionTest(xla_test.XLATestCase): def f(x, y): return x[0::2, y:, ...] - x = array_ops.ones([2, 3, 4]) + x = array_ops.ones([2, 3, 4], dtype=dtypes.float32) y = array_ops.ones([], dtype=dtypes.int32) with backprop.GradientTape() as tape: tape.watch(x) @@ -479,15 +481,15 @@ class EagerFunctionTest(xla_test.XLATestCase): @function.defun def times_two(x): - return 2 * x + return 2. * x @function.defun def two_x_plus_1(x): - return times_two(x) + 1 + return times_two(x) + 1. - x = constant_op.constant([2, 3, 4]) + x = constant_op.constant([2., 3., 4.]) y = two_x_plus_1(x) - self.assertAllEqual([5, 7, 9], y.numpy()) + self.assertAllEqual([5., 7., 9.], y.numpy()) def testNestedDefunWithVariable(self): with self.test_scope(): @@ -506,7 +508,7 @@ class EagerFunctionTest(xla_test.XLATestCase): x = constant_op.constant(3.0) y = f(x) - self.assertEqual(75, y.numpy()) + self.assertEqual(75.0, y.numpy()) def testNestedDefunInGradientTape(self): with self.test_scope(): @@ -555,6 +557,56 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(9, dy_v0.numpy()) self.assertEqual(15, dy_v1.numpy()) + def testWhileInDefun(self): + with self.test_scope(): + @def_function.function + def f(start): + c = lambda x: math_ops.less(x, 13.0) + b = lambda x: math_ops.add(x, 1.0) + return control_flow_ops.while_loop(c, b, [start]) + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testAutoGraphWhileInDefun(self): + with self.test_scope(): + @def_function.function + def f(start): + x = start + while x < 13.0: + x += 1.0 + return x + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + fn1 = lambda: math_ops.add(value, 1.0) + fn2 = lambda: math_ops.subtract(value, 1.0) + return control_flow_ops.cond(pred, fn1, fn2) + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + + def testAutoGraphCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + if pred: + return value + 1.0 + else: + return value - 1.0 + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 0e2d840418156d825e2d141018e49f42374c8fee..42e688174fce9e939feb09e1767ebab31e30a6ee 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -403,6 +403,117 @@ class AdjustSaturationTest(xla_test.XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) +class ResizeNearestNeighborTest(xla_test.XLATestCase): + # TODO(ilch): Wrap each test with `for dtype in self.float_types:` + # Some work to understand how that should be done was presented here: + # cl/227850213 + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None, + large_tolerance=False, + align_corners=True): + if expected is None: + self.fail("expected must be specified") + with self.cached_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_nearest_neighbor( + image, target_shape, align_corners=align_corners) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=2e-4, atol=2e-4) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners2x2To1x1(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners1x1To2x2(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [2, 2], + expected=np.array([[1, 1], [1, 1]], dtype=np.float32)) + + def testAlignCorners1x1To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [3, 3], + expected=np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [3, 3], + expected=np.array([[1, 2, 2], [3, 4, 4], [3, 4, 4]], dtype=np.float32)) + + def testAlignCorners2x2To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], + dtype=np.float32), large_tolerance=True) + + def testAlignCorners3x3To2x2(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32), [3, 3], + expected=np.array([[1, 3, 4], [9, 11, 12], [13, 15, 16]], + dtype=np.float32)) + + def testAlignCorners3x3To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 2, 2, 3], [4, 5, 5, 6], [4, 5, 5, 6], [7, 8, 8, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To6x6(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [6, 6], + expected=np.array( + [[1, 1, 2, 2, 3, 3], [1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6], + [4, 4, 5, 5, 6, 6], [7, 7, 8, 8, 9, 9], [7, 7, 8, 8, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9(self): + # The expected matrix might look uneven in terms of how many of each number + # there is, but this is an artifact of doing the dilation and convolution + # iteratively. The behavior is less esoteric in the 3x3To12x12 case below. + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [9, 9], + expected=np.array( + [[1, 2, 2, 2, 2, 3, 3, 3, 3], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To12x12(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [12, 12], + expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], + dtype=np.float32)) + + class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, @@ -444,14 +555,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis], out) - def testAlignCorners1x2To3x2(self): + def testAlignCorners1x2To3x3(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2]], dtype=dtype), [3, 3], expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) - def testAlignCorners1x2To3x2Grad(self): + def testAlignCorners1x2To3x3Grad(self): for dtype in self.float_types: self._assertBackwardOpMatchesExpected( np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index d23fd125163d1afe8c7fd5e008d4b617ff4b2874..1521cc760b85b176acb27c1489640e92ef90e247 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -63,6 +63,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,6 +81,7 @@ int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; int64 tf_xla_max_tensor_size = 10000LL; string* tf_xla_test_device_ptr; // initial value set in main() +string* tf_xla_reference_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { @@ -321,6 +323,9 @@ class OpTest : public ::testing::Test { // for use as reduction indices. Tensor RandomReductionIndices(int rank); + // Returns a random bit. + bool RandomBool(); + struct WindowedSpatialDims { Padding padding; std::vector kernel_dims; @@ -453,6 +458,11 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } +bool OpTest::RandomBool() { + std::bernoulli_distribution d(0.5); + return d(generator()); +} + Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); @@ -760,8 +770,22 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { return errors::InvalidArgument(absl::StrCat( - i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), - ". x = ", x.DebugString(), "y = ", y.DebugString())); + i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ", + Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString())); + } + } + return Status::OK(); +} + +Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (Tx(i) != Ty(i)) { + return errors::InvalidArgument(absl::StrCat( + i, "-th tensor element isn't equal: ", static_cast(Tx(i)), + " vs. ", static_cast(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString())); } } return Status::OK(); @@ -797,6 +821,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, return TensorsAreEqualImpl(a, b); case DT_BOOL: return TensorsAreEqualImpl(a, b); + case DT_BFLOAT16: + return TensorsAreEqualImplBfloat16(a, b); default: LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype()); } @@ -829,8 +855,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Input: " << input_tensors.back().DebugString(); } - string cpu_device = - LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); + string reference_device = + LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -845,9 +871,9 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - absl::StrCat("test", num_tests_, "_expected"), cpu_device, - /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, - &expected_inputs, &expected_fetches); + absl::StrCat("test", num_tests_, "_expected"), reference_device, + /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, + &expected_fetches); if (!status.ok()) { LOG(ERROR) << "Expected graph construction failed: " << status; return kFatalError; @@ -1371,6 +1397,19 @@ TEST_F(OpTest, Cast) { }); } +TEST_F(OpTest, CastBF16) { + Repeatedly([this]() { + DataType src_type, dst_type; + src_type = Choose({DT_FLOAT}); + dst_type = Choose({DT_BFLOAT16}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .RandomInput(src_type) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type) + .Attr("Truncate", true)); + }); +} + TEST_F(OpTest, Ceil) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -3346,11 +3385,41 @@ TEST_F(OpTest, ZerosLike) { }); } +// Example failing run: +// --tf_xla_reference_device=GPU:0 +// --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0 +// --tf_xla_test_repetitions=2 +// --gunit_filter='OpTest.FusedBatchNormTraining' +// --tf_xla_random_seed=2838146746 +TEST_F(OpTest, FusedBatchNormTraining) { + bool is_nhwc = RandomBool(); + std::vector x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4, + /*min_size=*/5, /*max_size=*/20); + std::vector scale_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector offset_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector mean_dims = {0}; + std::vector variance_dims = {0}; + DataType type = DT_FLOAT; + Repeatedly([&] { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FusedBatchNorm") + .RandomInput(type, x_dims) + .RandomInput(type, scale_dims) + .RandomInput(type, offset_dims) + .RandomInput(type, mean_dims) + .RandomInput(type, variance_dims) + .Attr("T", type) + .Attr("data_format", is_nhwc ? "NHWC" : "NCHW") + .Attr("epsilon", static_cast(1.001e-05)) + .Attr("is_training", true)); + }); +} } // anonymous namespace } // namespace tensorflow int main(int argc, char** argv) { tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); + tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0"); std::vector flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -3366,6 +3435,9 @@ int main(int argc, char** argv) { "Maximum number of elements for random input tensors."), tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, "Tensorflow device type to use for test"), + tensorflow::Flag("tf_xla_reference_device", + tensorflow::tf_xla_reference_device_ptr, + "Tensorflow device type to use for reference"), tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, "Use JIT compilation for the operator under test"), }; diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 5c079d595c440cac644f5461154509abe7b1d1ed..47e0f384a4f1e46ccc35584aaff3a0aceff8a985 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -23,24 +23,20 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -def scalar_shape(): - return ops.convert_to_tensor([], dtype=dtypes.int32) - - class ListOpsTest(xla_test.XLATestCase): def testElementShape(self): with self.cached_session() as sess, self.test_scope(): dim = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(dim, 15), num_elements=20, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(dim, 15), + element_dtype=dtypes.float32, + max_num_elements=20) e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) @@ -48,25 +44,44 @@ class ListOpsTest(xla_test.XLATestCase): def testPushPop(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) l = list_ops.tensor_list_push_back( l, constant_op.constant(1.0, shape=(7, 15))) l = list_ops.tensor_list_push_back( l, constant_op.constant(2.0, shape=(7, 15))) l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) - self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) - self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15))) + + def testDoNotConstantFoldVariants(self): + with self.cached_session() as sess, self.test_scope(): + val = array_ops.placeholder(dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) + # Note: Pushing a Placeholder will force the constant folding code + # to build a Const node with a DT_VARIANT output. This tests that XLA + # passes a cf_consider_fn which prevent folding such nodes. + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=val, dims=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15))) def testPushPopSeparateLists(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=scalar_shape(), - num_elements=num, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=[], + element_dtype=dtypes.float32, + max_num_elements=20) l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) @@ -75,22 +90,95 @@ class ListOpsTest(xla_test.XLATestCase): l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) - result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + result = sess.run([e11, [e21, e22], [e31, e32]]) self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) - def testEmptyTensorList(self): - dim = 7 + def testEmptyTensorListNoMax(self): with self.cached_session() as sess, self.test_scope(): - p = array_ops.placeholder(dtypes.int32) l = list_ops.empty_tensor_list( - element_shape=(p, 15), element_dtype=dtypes.float32) + element_shape=(7, 15), element_dtype=dtypes.float32) l = list_ops.tensor_list_push_back( - l, constant_op.constant(1.0, shape=(dim, 15))) + l, constant_op.constant(1.0, shape=(7, 15))) _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Use TensorListReserve instead"): - self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + "Set the max number of elements"): + self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + def testEmptyTensorListMax(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=(10, 15), element_dtype=dtypes.float32, + max_num_elements=2) + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=3.0, dims=(10, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) + + def testListFromTensor(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 2.0) + l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e1, 1.0) + self.assertAllEqual(list_ops.tensor_list_length(l), 0) + + def testGetSet(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 1.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 2.0]) + + def testGetSetReserved(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=2) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 0.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 0.0]) + + def testGetSetReservedNonScalar(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, + element_shape=(7, 15), + num_elements=2) + l = list_ops.tensor_list_set_item( + l, 0, constant_op.constant(1.0, shape=(7, 15))) + e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e1), np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) + + def testStack(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=[], + max_num_elements=2) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t.shape.as_list(), [None]) + self.assertAllEqual(t, [1.0, 2.0]) + + def testStackWithUninitializedTensors(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=3) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [0., 0., 0.]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 95c9e7ffd4651642781143c2c1940b0e51e1e470..3c2875ba477fa71e9e56a18d10efe0808533dd03 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -647,7 +647,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8..18c5870e0decb686f4df1c16bbb4a340c93ad21d 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -485,7 +485,7 @@ class SliceAssignTest(xla_test.XLATestCase): checker2[None] = [6] # new axis def testUninitialized(self): - with self.assertRaisesRegexp(errors.InvalidArgumentError, + with self.assertRaisesRegexp(errors.FailedPreconditionError, "uninitialized variable"): with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable([1, 2]) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a67e511826ae161e78d504c1513934065cbfd19f --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -0,0 +1,440 @@ +# Description: +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. +# APIs are meant to change over time. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_cuda_cc_test( + name = "tensorrt_test_cc", + size = "small", + srcs = ["tensorrt_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + "//tensorflow/core:gpu_init", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_custom_op_library( + name = "python/ops/_trt_ops.so", + srcs = [ + "ops/get_serialized_resource_op.cc", + "ops/trt_engine_op.cc", + ], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "trt_op_kernels", + srcs = [ + "kernels/get_serialized_resource_op.cc", + "kernels/trt_engine_op.cc", + ], + hdrs = [ + "kernels/trt_engine_op.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":test_utils", + ":trt_allocator", + ":trt_conversion", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + ":utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + # TODO(laigd): fix this by merging header file in cc file. + alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs +) + +tf_cuda_cc_test( + name = "get_serialized_resource_op_test", + size = "small", + srcs = ["kernels/get_serialized_resource_op_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":get_serialized_resource_op_op_lib", + ":trt_op_kernels", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "trt_engine_op", + "get_serialized_resource_op", + ], +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["utils/trt_logger.cc"], + hdrs = ["utils/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_ops", + deps = [ + ":get_serialized_resource_op_op_lib", + ":trt_engine_op_op_lib", + ":trt_logging", + ], +) + +tf_custom_op_py_library( + name = "trt_ops_loader", + srcs = ["python/ops/trt_ops.py"], + dso = [ + "python/ops/_trt_ops.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), + kernels = [ + ":trt_op_kernels", + ":trt_engine_op_op_lib", + ":get_serialized_resource_op_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + ], +) + +tf_cuda_library( + name = "trt_resources", + srcs = [ + "utils/trt_int8_calibrator.cc", + "utils/trt_resource_manager.cc", + "utils/trt_resources.cc", + ], + hdrs = [ + "utils/trt_int8_calibrator.h", + "utils/trt_lru_cache.h", + "utils/trt_resource_manager.h", + "utils/trt_resources.h", + ], + deps = [ + ":trt_allocator", + ":trt_logging", + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_library( + name = "trt_allocator", + srcs = ["utils/trt_allocator.cc"], + hdrs = ["utils/trt_allocator.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["utils/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "trt_lru_cache_test", + size = "small", + srcs = ["utils/trt_lru_cache_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_resources", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + "convert/trt_optimization_pass.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + "convert/trt_optimization_pass.h", + ], + deps = [ + ":segment", + ":test_utils", + ":trt_allocator", + ":trt_plugins", + ":trt_logging", + ":trt_resources", + ":utils", + "@com_google_absl//absl/strings", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_cuda_cc_test( + name = "convert_graph_test", + size = "medium", + srcs = ["convert/convert_graph_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_conversion", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "convert_nodes_test", + size = "medium", + srcs = ["convert/convert_nodes_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_logging", + ":trt_conversion", + ":trt_plugins", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":segment", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "utils", + srcs = ["convert/utils.cc"], + hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_utils", + srcs = ["utils/test_utils.cc"], + hdrs = ["utils/test_utils.h"], + deps = [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc similarity index 94% rename from tensorflow/contrib/tensorrt/convert/convert_graph.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index bf2de94e04ae3f6817f7a679ce9fd88e750827dd..1fdb099cc1d658b4259177e357b639ea72d636d0 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include @@ -24,13 +24,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -63,8 +64,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // Returns compiled TRT version information {Maj, Min, Patch} std::vector GetLinkedTensorRTVersion() { @@ -151,7 +152,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { is_supported_op_type = true; } - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) + // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), " is not supported"); @@ -334,13 +335,12 @@ struct EdgePtrCompare { tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, + const std::set& segment_nodes, const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; // Topologically sorted node ids. - std::set subgraph_node_names = segment_nodes; - std::set added_const_node_ids; // Used to prevent double insertion. + std::vector subgraph_nodes; // Topologically sorted nodes. + std::set added_const_nodes; // Used to prevent double insertion. std::set segment_devices; // Map from src_node_name+port to the unique port numbers of the TRT op, where @@ -352,22 +352,37 @@ tensorflow::Status GetEngineInfo( std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { - const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = *it; + const Node* node = *it; + if (segment_nodes.count(node) == 0) continue; auto node_device = node->requested_device(); if (!node_device.empty()) { - segment_devices.insert(node_device); + // If device is CPU, treat as if no device was assigned. Don't add CPU to + // segment_device because that would cause a segfault in + // GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes + // any already set device is a GPU. + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseFullName(node_device, &parsed_name); + if (parsed_name.type == "CPU") { + VLOG(1) << "Node " << node->name() << " was assigned to the CPU. " + << "Attempting to place on GPU."; + } else { + segment_devices.insert(node_device); + } } else { if (node->has_assigned_device_name()) { + // It appears that nodes will not have assigned devices at this point in + // execution. segment_devices.insert(node->assigned_device_name()); } else { VLOG(2) << "Node " << node->name() << " neither have requested device nor assigned device"; } } + subgraph_nodes.push_back(node); + const int node_id = node->id(); - subgraph_node_ids.push_back(node_id); + const string& node_name = node->name(); + // Create input connections. Sort edges first to make determnistic since // in_edges is a set of pointers. std::vector in_edges(node->in_edges().begin(), @@ -375,7 +390,7 @@ tensorflow::Status GetEngineInfo( std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); for (const auto edge : in_edges) { auto input_node = edge->src(); - if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + if (input_node->IsSource() || segment_nodes.count(input_node)) { continue; } if (edge->IsControlEdge()) { @@ -392,12 +407,11 @@ tensorflow::Status GetEngineInfo( // // Note that the segmenter already ensure that the constant data input // is valid and suppported by the engine. - if (!added_const_node_ids.insert(input_node->id()).second) { + if (!added_const_nodes.insert(input_node).second) { // Already added before. continue; } VLOG(1) << "Adding const node " << input_node->name(); - QCHECK(subgraph_node_names.insert(input_node->name()).second); // Since we already add (duplicate) the const input node to the segment // graphdef, it's now not a data dependency any more, but to make the // dependency correct we still add a control dependency. @@ -428,7 +442,7 @@ tensorflow::Status GetEngineInfo( std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); for (const auto edge : out_edges) { auto output_node = edge->dst(); - if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + if (output_node->IsSink() || segment_nodes.count(output_node)) { continue; } if (edge->IsControlEdge()) { @@ -456,12 +470,11 @@ tensorflow::Status GetEngineInfo( } // For each segment node in topological order. // Construct the const nodes first. - subgraph_node_ids.insert(subgraph_node_ids.begin(), - added_const_node_ids.begin(), - added_const_node_ids.end()); + subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), + added_const_nodes.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_names, subgraph_node_ids, - &info->connections, &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_nodes, &info->connections, + &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -654,14 +667,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, segment_string = info.segment_graph_def.SerializeAsString(); } - // TODO(aaroey): use enum instead, and add a helper method to do the - // conversion. string prec_string; TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && calibrate_int8 && - !TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); if (VLOG_IS_ON(1)) { @@ -677,7 +684,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } if (info.engine_type == EngineInfo::EngineType::TRTStatic && - info.cached_engine_batches.size()) { + !info.cached_engine_batches.empty()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; } tensorflow::NodeDef trt_node; @@ -691,7 +698,6 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, .Attr("serialized_segment", segment_string) .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) - .Attr("cached_engine_batches", {max_batch_size}) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) .Attr("precision_mode", prec_string) .Attr("use_calibration", info.use_calibration) @@ -1033,27 +1039,31 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(cuda_device_id); auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, &graph, alloc.get(), &engine_nodes); - // If status is ok, we successfully added the node to the graph and can - // remove segment ops. Otherwise graph is not modified. + string msg = StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { LOG(INFO) << msg << " succeeded."; - for (auto node_name : converted_segments.at(i).first) { - graph.RemoveNode(node_map.at(node_name)); - } } else { // Graph is not modified. LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); + for (const Node* node : converted_segments.at(i).first) { + StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; } + + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + if (status.ok()) { + for (const Node* node : converted_segments.at(i).first) { + graph.RemoveNode(const_cast(node)); + } + } } cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h similarity index 94% rename from tensorflow/contrib/tensorrt/convert/convert_graph.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 1f39f56f6392ba33af3d74fec12c326ed4451cb6..fb82a430c632781047487a280e23e7da4c385929 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -123,4 +123,4 @@ std::pair GetDeviceAndAllocator( #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc similarity index 98% rename from tensorflow/contrib/tensorrt/convert/convert_graph_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2d2bfeb192c1893824c7b30bfad593c62c203392..a3c3a8ac6561259c974aebb6c6eeac05c71b7161 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc similarity index 94% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index adf8831b960172fc29b5d631e5b0533318d4764d..c08582a42e24fd55e785ad045725e06f1d414bfd 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -24,11 +24,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -43,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" @@ -81,9 +83,9 @@ const char* const kInputPHName = "TensorRTInputPH_"; const char* const kOutputPHName = "TensorRTOutputPH_"; namespace convert { +using absl::StrAppend; +using absl::StrCat; using ::tensorflow::str_util::Split; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, nvinfer1::DataType* trt_dtype) { @@ -334,6 +336,21 @@ Status Converter::GetTrtBroadcastShape( return Status::OK(); } +nvinfer1::ITensor* Converter::CreateConstantLayer( + const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) { + nvinfer1::Weights trt_weights = weights.GetTrtWeights(); + nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); + if (!layer) return nullptr; + const nvinfer1::DataType trt_dtype = trt_weights.type; + nvinfer1::ITensor* trt_tensor = layer->getOutput(0); + // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set + // the data type below, it will always be kFLOAT regardless what the data type + // of the weights is. Once NVIDIA fixes this bug, we should remove the data + // type setting logic below and test should still pass. + trt_tensor->setType(trt_dtype); + return trt_tensor; +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -879,6 +896,8 @@ Status Converter::ConvertNode(const NodeDef& node_def) { // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { @@ -939,6 +958,22 @@ Status Converter::RenameAndMarkOutputTensors( if (tensor == nullptr) { return errors::NotFound("Output tensor not found: ", output.first); } + // Check if this tensor has already been marked as an output. + // ConvertIdentity can cause the same tensor to be repeated in + // output_tensors, which can cause us to overwrite the name of the output + // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then + // we won't be able to locate OutputPH_0 during runtime. To fix this, + // duplicate the tensor using no-op shuffle. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. + if (tensorflow::str_util::StartsWith(tensor->getName(), kOutputPHName)) { + // Using shuffle layer for identity by not setting reshape or transpose. + nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); + TFTRT_RETURN_ERROR_IF_NULLPTR( + layer, StrCat("Output Copy for ", tensor->getName())); + MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); + tensor = layer->getOutput(0); + } tensor->setName(output.second.c_str()); VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " << output.second; @@ -1086,10 +1121,8 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, *tensor = layer->getOutput(0); } } else { - nvinfer1::IConstantLayer* layer = - this->network()->addConstant(dims, input.weights().GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); - *tensor = layer->getOutput(0); + *tensor = CreateConstantLayer(input.weights(), dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape"); if (precision_mode() == INT8MODE && !use_calibration()) { // If we are in int8 mode and not calibrating, we need to explicitly set a // quantization range for the output tensor of the IConstantLayer. Here we @@ -1538,6 +1571,11 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument("Two inputs are expected for ", + node_def.op(), ", at ", + node_def.name()); + } if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( node_def.op(), " is only implemented for tensors, not weights, at ", @@ -1549,39 +1587,61 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { node_def.name()); } TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + return tensorflow::errors::InvalidArgument( + "Conv2D expects kernel of dimension 4, at " + node_def.name()); } + TFAttrs attrs(node_def); + auto data_format = attrs.get("data_format"); + int c_index = (data_format == "NHWC") ? 3 : 1; + int h_index = (data_format == "NHWC") ? 1 : 2; + int w_index = (data_format == "NHWC") ? 2 : 3; + auto tf_dilations = attrs.get>("dilations"); + if (tf_dilations.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution dilations field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Dilation rate must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]); + + const auto tf_stride = attrs.get>("strides"); + if (tf_stride.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution strides field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Stride must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); if (params->validation_only) return tensorflow::Status::OK(); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TFAttrs attrs(node_def); - int h_index = 2; - int w_index = 3; - auto data_format = attrs.get("data_format"); - if (data_format == "NHWC") { + // Transpose to NCHW (NCHW is required for IConvLayer). + const bool need_transpose = (data_format == "NHWC"); + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(tensor), {0, 3, 1, 2}, &tensor)); - h_index = 1; - w_index = 2; - // TODO(jie): transpose it } - - // tensor after transpose (NCHW) + // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); - int num_groups = group; - if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution - VLOG(2) << "groups count: " << num_groups; + // For depthwise convolution, group will be 0 so set num_groups to size of + // input's channel dim. For a non-depthwise conv, num_groups will be 1. + const int num_groups = (group == 0) ? tensor_dim.d[0] : group; if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); @@ -1590,35 +1650,22 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; - VLOG(2) << "RSCK: " << weights.DebugString(); - VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); - - // TODO(jie): stride. (NHWC/NCHW) - const auto tf_stride = attrs.get>("strides"); - VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; - VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2] - << tf_stride[3]; - const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + // Add padding. std::vector> padding; - // TODO(jie): padding. if (attrs.get("padding") == "SAME") { - // This is NCHW tensor with no batch dimension. - // 1 -> h - // 2 -> w + nvinfer1::DimsHW effective_kernel_size = kernel_size; + effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1); + effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1); padding = CreateSamePadding( - stride, kernel_size, + stride, effective_kernel_size, {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else { padding = {{0, 0}, {0, 0}}; } - if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding - VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second - << padding[1].first << padding[1].second; - VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions()); + // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), @@ -1628,24 +1675,23 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); - VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); } + // Add convolution. nvinfer1::IConvolutionLayer* layer = params->converter->network()->addConvolution( *const_cast(tensor), noutput, kernel_size, weights.GetTrtWeights(), biases.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); + layer->setDilation(dilation); const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions()); - VLOG(2) << "data_format: " << data_format; - if (data_format == "NHWC") { - // TODO(jie): transpose it back! + + // Restore transpose. + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), {0, 2, 3, 1}, &output_tensor)); @@ -1694,6 +1740,13 @@ Status BinaryTensorOpTensor(OpConverterParams* params, "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", status.error_message()); } + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("T"); + if (dtype == nvinfer1::DataType::kINT32) { + return errors::Unimplemented("Binary op ", node_def.op(), + " does not support INT32, at ", + node_def.name()); + } if (params->validation_only) return Status::OK(); const nvinfer1::ITensor* tensor_l = nullptr; @@ -1710,8 +1763,6 @@ Status BinaryTensorOpTensor(OpConverterParams* params, } // Check type consistency. - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) @@ -2534,22 +2585,18 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { auto weights_ptr = static_cast(const_cast(weights.GetValues())); weights_ptr[0] = 6.0f; - nvinfer1::IConstantLayer* const6_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); - params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, - 6.0f); + nvinfer1::ITensor* const6_tensor = + params->converter->CreateConstantLayer(weights, dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(const6_tensor, node_def.name()); + params->converter->ProvideQuantizationRange(const6_tensor, 0.0f, 6.0f); // ElementWise Min Operation // Min op is a nop for INT8 execution path, as the input tensor // to this layer will only have values in range [0.f, 6.0f]. - const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0); - const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0); nvinfer1::IElementWiseLayer* relu6_layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), - nvinfer1::ElementWiseOperation::kMIN); + *const_cast(relu_layer->getOutput(0)), + *const6_tensor, nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); @@ -2566,12 +2613,18 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { return errors::InvalidArgument("Input expects tensor and weights, at ", node_def.name()); } + TFAttrs attrs(node_def); + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); + } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* tensor = const_cast(inputs.at(0).tensor()); const nvinfer1::Dims original_dims = tensor->getDimensions(); - TFAttrs attrs(node_def); const string data_format = attrs.get("data_format"); const int channel_index = (data_format == "NHWC" ? original_dims.nbDims - 1 : 0); @@ -2661,43 +2714,69 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { return Status::OK(); } -Status GetTensorDimsWithProtoShape(const Tensor& tensor, - int tensor_proto_array_len, - nvinfer1::Dims* dims) { +void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { if (tensor.dims() > 0) { *dims = GetTrtDimsForTensor(tensor); - if (TrtDimsNumElements(*dims) != tensor_proto_array_len && - tensor_proto_array_len != 1) { - return errors::InvalidArgument( - "Broadcast on weights only supports kCHANNEL and kUNIFORM"); - } } else { dims->nbDims = 1; // No dimension provided. Flatten it. - dims->d[0] = tensor_proto_array_len; + dims->d[0] = tensor.NumElements(); dims->type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) { dims->d[i] = 0; } } - return Status::OK(); } -template -Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, - const CType* tensor_proto_array, - int tensor_proto_array_len, TrtWeightStore* store, +Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { + const DataType dtype = tensor.dtype(); + + // We always convert the integer constants to INT32, since TRT INT8 is for + // quantized inference. + // + // TODO(aaroey): FP16 will remain in half format and is not converted to + // FP32, but the converter currently uses all float weights as FP32. Fix + // this. + const DataType converted_dtype = + (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 + : dtype); + + // Verify that the dtype is supported by TensorRT. Otherwise, return an error. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + + if (tensor.NumElements() == 0) { + // Return empty weights having converted dtype. + *weights = TRT_ShapedWeights(converted_dtype); + return Status::OK(); + } + nvinfer1::Dims weight_dims; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape(tensor, tensor_proto_array_len, - &weight_dims)); - *weights = store->GetTempWeights(dtype, weight_dims); - void* dst = const_cast(weights->GetValues()); - if (tensor_proto_array_len == 1) { - std::fill_n((CType*)dst, TrtDimsNumElements(weight_dims), - *tensor_proto_array); + GetTensorDimsWithProtoShape(tensor, &weight_dims); + *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + + // Copy the tensor directly if the tensor does not require cast to the + // supported type. + if (converted_dtype == dtype) { + char* dst = static_cast(const_cast(weights->GetValues())); + memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes()); + return Status::OK(); + } + + // Copy tensor elements after casting them to the converted DataType. + int32* dst = static_cast(const_cast(weights->GetValues())); + if (dtype == DT_INT16) { + const int16* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); + } else if (dtype == DT_INT8) { + const int8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } else { - memcpy(dst, tensor_proto_array, weights->size_bytes()); + // dtype can only be DT_UINT8 at this point. + TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); + const uint8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } return Status::OK(); } @@ -2715,15 +2794,6 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { "Constant node is expected to have empty input list: ", node_def.name()); } - TFAttrs attrs(node_def); - const DataType dtype = attrs.get("dtype"); - // We always convert the integer constants to kINT32, since TRT kINT8 is for - // quantized inference. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); // Create shaped weights as output const auto& tensor_proto = node_def.attr().at("value").tensor(); @@ -2733,78 +2803,18 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { node_def.name()); } - TRT_ShapedWeights weights(converted_dtype); - if (tensor.NumElements() == 0) { - // Do nothing. - } else if (!tensor_proto.float_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.float_val().begin(), - tensor_proto.float_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.int_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.int_val().begin(), - tensor_proto.int_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.half_val().empty()) { - // TODO(aaroey): implement fp16 conversion. - return errors::Unimplemented("fp16 constant is not supported yet."); - } else if (!tensor_proto.tensor_content().empty()) { - // TODO(aaroey): fp16 will remain in half format and is not converted to - // fp32, but the converter currently uses all float weights as fp32. Fix - // this. - const auto& content = tensor_proto.tensor_content(); - if (content.size() > 0) { - const int dtype_size = tensorflow::DataTypeSize(dtype); - if (content.size() % dtype_size != 0) { - return errors::FailedPrecondition("Tensor content size ", - content.size(), - " is not a multiple of ", dtype_size); - } - nvinfer1::Dims weights_dim; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape( - tensor, content.size() / dtype_size, &weights_dim)); - const int64_t size_bytes = TrtDimsNumElements(weights_dim) * dtype_size; - if (content.size() != size_bytes) { - return errors::FailedPrecondition( - "Tensor size and TensorProto content size mismatch: ", size_bytes, - " vs ", content.size()); - } else if (tensor.NumElements() != content.size() / dtype_size) { - return errors::FailedPrecondition( - "Tensor elements count and TensorProto content size mismatch: ", - tensor.NumElements(), " vs ", content.size() / dtype_size); - } - weights = - params->weight_store->GetTempWeights(converted_dtype, weights_dim); - if (dtype_size == tensorflow::DataTypeSize(converted_dtype)) { - port::CopyToArray(content, static_cast( - const_cast(weights.GetValues()))); - } else { - // Copy out the weights as original data type. - std::vector temp_weights(content.size()); - port::CopyToArray(content, - reinterpret_cast(temp_weights.data())); - int32* dst = - static_cast(const_cast(weights.GetValues())); - // Copy to the weight store as converted data type. - if (dtype == DT_INT16) { - int16* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - int8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_UINT8) { - uint8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else { - return errors::FailedPrecondition( - "Unexpected data type: ", DataTypeString(dtype), - " at: ", node_def.name()); - } - } - } - } else { - return errors::Unimplemented("Not supported constant type, at ", - node_def.name()); + TFAttrs attrs(node_def); + const DataType dtype = attrs.get("dtype"); + if (dtype != tensor.dtype()) { + return errors::InvalidArgument("DataType mismatch between attr (", + DataTypeString(dtype), ") and tensor (", + DataTypeString(tensor.dtype()), ")"); } + + TRT_ShapedWeights weights; + TF_RETURN_IF_ERROR( + TfTensorToTrtWeights(tensor, params->weight_store, &weights)); + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } @@ -2947,18 +2957,15 @@ tensorflow::Status ConvertSquare(OpConverterParams* params) { auto weights_ptr = static_cast(const_cast(weights.GetValues())); weights_ptr[0] = 2.f; - nvinfer1::IConstantLayer* const2_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const2_layer, node_def.name()); + nvinfer1::ITensor* const2_tensor = + params->converter->CreateConstantLayer(weights, dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, node_def.name()); // ElementWise Pow Operation - const nvinfer1::ITensor* tensor_l = inputs.at(0).tensor(); - const nvinfer1::ITensor* tensor_r = const2_layer->getOutput(0); nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), - nvinfer1::ElementWiseOperation::kPOW); + *const_cast(inputs.at(0).tensor()), + *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3418,7 +3425,6 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { } TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { return errors::Unimplemented("Data type is not supported, for node ", @@ -3444,7 +3450,6 @@ tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { const auto& node_def = params->node_def; TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != tensorflow::DataType::DT_FLOAT && tf_dtype != tensorflow::DataType::DT_HALF) { @@ -3566,6 +3571,9 @@ tensorflow::Status ConvertTopK(OpConverterParams* params) { nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); + // Tensor type for network output is not inferred. Indices should be INT32 + // (default is float). + output_indices_tensor->setType(nvinfer1::DataType::kINT32); params->outputs->push_back(TRT_TensorOrWeights(output_value_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_indices_tensor)); return tensorflow::Status::OK(); @@ -3686,7 +3694,7 @@ tensorflow::Status ConvertGraphDefToEngine( if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && (node_def.op() == "Placeholder")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kInputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); @@ -3715,7 +3723,7 @@ tensorflow::Status ConvertGraphDefToEngine( } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kOutputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); @@ -3749,8 +3757,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, // In topological order + const std::vector& subgraph_nodes, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { std::set marker_nodes; @@ -3813,8 +3820,10 @@ tensorflow::Status ConvertSegmentToGraphDef( marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Identity"); - auto status = builder.Input(connection.inside_node_name, 0, dtype) - .Finalize(seg_node); + auto status = + builder + .Input(connection.inside_node_name, connection.inside_port, dtype) + .Finalize(seg_node); VLOG(1) << "Constructing output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port << " -> " << connection.outside_node_name << ":" @@ -3824,11 +3833,10 @@ tensorflow::Status ConvertSegmentToGraphDef( std::unordered_map old_to_new_id_map; // Copy internal nodes to new graphdef - string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); - for (const auto node_id : subgraph_node_ids) { - const auto node = graph->FindNodeId(node_id); + string local_scope = subgraph_nodes.front()->name(); + for (const Node* node : subgraph_nodes) { local_scope = GetCommonNameScope(local_scope, node->name()); - old_to_new_id_map[node_id] = segment_def->node_size(); + old_to_new_id_map[node->id()] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); VLOG(2) << "Copying " << snode->name() << " to subgraph"; @@ -3846,6 +3854,11 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + std::set subgraph_node_names; + for (const Node* node : subgraph_nodes) { + subgraph_node_names.insert(node->name()); + } + // Remove control inputs that are not inside the segment. for (int i = 0; i < segment_def->node_size(); ++i) { auto snode = segment_def->mutable_node(i); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h similarity index 94% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 54e19b73957bccdae2b23bd3556de9ad00b864e5..aebc0ca38de449dd716b3948f9a0b2e581fc8c80 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ #include #include @@ -22,11 +22,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -128,8 +128,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, + const std::vector& subgraph_nodes, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); @@ -159,7 +158,10 @@ class OutputEdgeValidator { bool operator()(const tensorflow::Edge* out_edge) const; }; +string DebugString(const nvinfer1::DimensionType type); +string DebugString(const nvinfer1::DataType trt_dtype); string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::Permutation& permutation, int len); string DebugString(const nvinfer1::ITensor& tensor); int64_t TrtDimsNumElements(const nvinfer1::Dims& dims); @@ -195,6 +197,10 @@ class TRT_ShapedWeights { // underlying buffer. TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + // All weights should be stored inside TrtWeightStore to make sure lifetime of + // all the underlying tensors are available until the engine is built. For + // this reason, tensor_ should never be reassigned to a different value that + // is not already present in the TrtWeightStore. Tensor tensor_; friend class TrtWeightStore; @@ -469,6 +475,11 @@ class Converter { nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_r_new_dims) const; + // Creates an IConstantLayer using 'weights' whose dimensions are specified by + // 'dims', and returns the output ITensor. + nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, + const nvinfer1::Dims& dims); + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -544,4 +555,4 @@ class Converter { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc similarity index 89% rename from tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index a2ddfbffa5b0d8c421bcfe054097a9e42b79fe8f..3a70423d12b35e46d2709dcdc25920a3143f41c4 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -21,11 +21,12 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" @@ -50,7 +52,7 @@ namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; @@ -364,9 +366,6 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_tensor()); EXPECT_EQ(true, ptr->is_weights()); EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); - - nvinfer1::Dims dims; - dims.nbDims = 0; ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } @@ -915,6 +914,20 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { "(tensor #dims 4 vs broadcast #dims 5)"); } +TEST_F(ConverterTest, CreateConstantLayer) { + for (auto dtype : {DT_FLOAT, DT_INT32}) { + TRT_ShapedWeights weights = + weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); + nvinfer1::ITensor* tensor = + converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); + ASSERT_NE(nullptr, tensor); + EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) + << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual " + << DebugString(tensor->getType()); + ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); + } +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -1111,6 +1124,30 @@ class OpConverterTest : public ::testing::Test { std::unordered_map validator_inputs_; }; +template +void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { + out->Clear(); + if (tensor.NumElements() == 0) return; + + // TensorProto does not need to have all the elements present and can truncate + // trailing elements with the same value for compressed representation. Such + // elements are derived based on the tensor shape. + const auto flat = tensor.flat(); + int64 last_index = 0; + for (int64 i = 0; i < tensor.NumElements(); ++i) { + if (flat(i) != flat(last_index)) { + last_index = i; + } + } + + int num_out_elements = last_index + 1; + out->Reserve(num_out_elements); + out->AddNAlreadyReserved(num_out_elements); + const T* src = flat.data(); + T* dst = out->mutable_data(); + std::copy(src, src + num_out_elements, dst); +} + template void TestConvertConst(OpConverterTest* test) { NodeDef node_def; @@ -1123,11 +1160,23 @@ void TestConvertConst(OpConverterTest* test) { const std::vector& expected_value) { test->Reset(); - auto& attr = *node_def.mutable_attr(); + TensorProto* tensor_attr = + (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor_attr->Clear(); + if (as_tensor_content) { - tensor.AsProtoTensorContent(attr["value"].mutable_tensor()); + tensor.AsProtoTensorContent(tensor_attr); } else { - tensor.AsProtoField(attr["value"].mutable_tensor()); + tensor.shape().AsProto(tensor_attr->mutable_tensor_shape()); + tensor_attr->set_dtype(tensor.dtype()); + + if (tensor.dtype() == DT_FLOAT) { + CopyTensorElements(tensor, tensor_attr->mutable_float_val()); + } else if (tensor.dtype() == DT_INT32) { + CopyTensorElements(tensor, tensor_attr->mutable_int_val()); + } else { + tensor.AsProtoField(tensor_attr); + } } test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ -1140,8 +1189,7 @@ void TestConvertConst(OpConverterTest* test) { { // By default empty tensor will pick DT_FLOAT as data type and we fix it // here. - attr["value"].mutable_tensor()->set_dtype(dtype); - Tensor t; // Empty tensor. + Tensor t(dtype); // Empty tensor. reset_and_test(t, false, {}, {}); } { @@ -1160,6 +1208,22 @@ void TestConvertConst(OpConverterTest* test) { reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6}); reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6}); } + { + // Set all tensor elements to the same value. Such tensors are encoded + // using a single element list in tensor proto. + Tensor t = ::tensorflow::test::AsTensor({1, 1, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1}); + } + { + // Set trailing tensor elements to the same value. Such tensors are + // encoded by truncating all equal elements except the first one. + Tensor t = ::tensorflow::test::AsTensor({2, 2, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1}); + } } TEST_F(OpConverterTest, ConvertConst) { @@ -2253,7 +2317,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); ops::Squeeze::Attrs squeeze_attrs; - squeeze_attrs.axis_ = gtl::ArraySlice(axis); + squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); return squeeze.operation.node()->def(); @@ -2378,6 +2442,8 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { }; { + // Input is weights, should fail. + Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); AddTestWeights("begin", {4}, {0, 0, 0, 0}); @@ -2619,6 +2685,240 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { } } +TEST_F(OpConverterTest, ConvertConv2D) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs are expected for Conv2D, at my_conv2d"); + } + + // Get nodedef for Conv2D layer. + auto get_conv2d_nodedef = + [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, + padding, attrs); + return conv2d.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Conv2D is only implemented for tensors, not weights, at my_conv2d"); + } + { + // Filter is tensor, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3, 3, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Kernel for Conv2D must be constant weights, at my_conv2d"); + } + { + // Filter is not 4D, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D expects kernel of dimension 4, at my_conv2d"); + } + { + // Dilations is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution dilations field must specify 4 dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel (NHWC), should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Strides is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution strides field must specify 4 dimensions, at my_conv2d"); + } + { + // Stride value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Stride must be 1 for batch and channel dimensions, at my_conv2d"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& input, + const std::vector& filter_dims, + const std::vector& filter, + const std::vector& strides, const string& padding, + const string& data_format, const std::vector& dilations, + const std::vector& expected_output_dims, + const std::vector& expected_output) + : input_dims(input_dims), + input(input), + filter_dims(filter_dims), + filter(filter), + strides(strides), + padding(padding), + data_format(data_format), + dilations(dilations), + expected_output_dims(expected_output_dims), + expected_output(expected_output) {} + + std::vector input_dims; + std::vector input; + std::vector filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + const int kConv2DOKCases = 6; + TestParams ok_params[kConv2DOKCases] = { + // Basic + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, + // NHWC + TestParams{/*input_dims=*/{2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, + }; + + for (int i = 0; i < kConv2DOKCases; i++) { + Reset(); + NodeDef node_def = + get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, + ok_params[i].data_format, ok_params[i].dilations); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", ok_params[i].input}}, "my_conv2d", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc similarity index 96% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index d57f2300f8e6e6ce79c538133da6bc5cf5ead2f5..ebf8df1349363e9986020ea705b32edfef43bc93 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -12,9 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -30,9 +32,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { // TODO(sami): Remove VLOG messages once the code matures +using absl::StrAppend; +using absl::StrCat; using tensorflow::str_util::Uppercase; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; tensorflow::Status TRTOptimizationPass::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { @@ -243,7 +245,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( // If the last token is not an integer, it must be part of the name. // Otherwise it is port number. if (tokens.size() > 1 && - !strings::safe_strto32(tokens.back(), &dumm_port)) { + !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok StrAppend(&s, ":", tokens.back()); } nodes_to_preserve.push_back(s); diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h similarity index 91% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index 3e8dc0978e43e2e9ba07aaa09f74acfe8e59b9a7..bd6c6dbce1ddb8757227a1c71408770ee8be48d8 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ #include @@ -77,4 +77,4 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc similarity index 97% rename from tensorflow/contrib/tensorrt/convert/utils.cc rename to tensorflow/compiler/tf2tensorrt/convert/utils.cc index e7a1febb8c076891596741fe30721e7acca15a73..62a0f62ad6657f2d1551cd093f4f2d93c25f4cae 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h similarity index 88% rename from tensorflow/contrib/tensorrt/convert/utils.h rename to tensorflow/compiler/tf2tensorrt/convert/utils.h index 0592f31462af2b20f3a13fe5119e89c2ba42dd8a..9f9ee59087d461bdc825346d9adc976e42f47c5e 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ #include @@ -47,4 +47,4 @@ Status GetPrecisionMode(const string& name, int* precision_mode); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eae1f8e7525f1816d1c50072ebe4ba6713c96e47 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/refcount.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOp : public OpKernel { + public: + explicit GetSerializedResourceOp(OpKernelConstruction* context) + : OpKernel(context) {} + + ~GetSerializedResourceOp() override {} + + void Compute(OpKernelContext* context) override { + // TODO(laigd): it will allocate the tensor on the device and copy the + // serialized string to that tensor, and later sess.run() will copy it back + // to host. We need to optimize this. + const string& container = context->input(0).scalar()(); + const string& resource_name = context->input(1).scalar()(); + + // Get the resource. + SerializableResourceBase* resource = nullptr; + OP_REQUIRES_OK(context, context->resource_manager()->Lookup( + container, resource_name, &resource)); + ::tensorflow::core::ScopedUnref sc(resource); + + // Serialize the resource as output. + string serialized_resource; + OP_REQUIRES_OK(context, resource->SerializeToString(&serialized_resource)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = serialized_resource; + } +}; + +REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), + GetSerializedResourceOp); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec038ebda073c8050321d5668b15a2c6faa72a4b --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOpTest : public OpsTestBase {}; + +TEST_F(GetSerializedResourceOpTest, Basic) { + // Create the GPU device. + std::unique_ptr device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + + // Create the resource. + class MySerializableResource : public SerializableResourceBase { + public: + string DebugString() const override { return ""; } + Status SerializeToString(string* serialized) override { + *serialized = "my_serialized_str"; + return Status::OK(); + } + }; + const string container = "mycontainer"; + const string resource_name = "myresource"; + SerializableResourceBase* resource = new MySerializableResource(); + ResourceMgr* rm = device->resource_manager(); + EXPECT_TRUE(rm->Create(container, resource_name, resource).ok()); + + // Create the op. + SetDevice(DEVICE_GPU, std::move(device)); + TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + // Execute the op. + AddInputFromArray(TensorShape({}), {container}); + AddInputFromArray(TensorShape({}), {resource_name}); + TF_ASSERT_OK(RunOpKernel()); + + // Verify the result. + // TODO(laigd): OpsTestBase::GetOutput() doesn't work. + Tensor* output = context_->mutable_output(0); + EXPECT_EQ("my_serialized_str", output->scalar()()); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc similarity index 67% rename from tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index bad568644bb1f8d01d4cb0a7c853ec47d6f19e45..198d68b60985d2b3f2ef958c4f13f94054d4875a 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -12,17 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include "tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h" #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -38,9 +40,9 @@ limitations under the License. namespace tensorflow { namespace tensorrt { static Logger logger; +using absl::StrAppend; +using absl::StrCat; using ::nvinfer1::IRuntime; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; // A helper class to call done() when destructed for asynchronous execution. // Helps simultaneous execution of native and TRT engines. @@ -135,8 +137,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) native_func_ = tensorflow::kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); - OP_REQUIRES_OK(context, - context->GetAttr("fixed_input_size", &fixed_input_size_)); OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", &cached_engine_batches_)); std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); @@ -175,11 +175,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, lib->Run(opts, native_func_, inputs, outputs, [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); - VLOG(1) << "Native Segment completed"; if (!s.ok()) { + LOG(ERROR) << "Failed to execute native segment " << this->name() + << ": " << s; ctx->SetStatus(s); return; } + VLOG(1) << "Native Segment completed"; for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } @@ -194,18 +196,37 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); tensorflow::core::ScopedUnref sc(helper); - // TODO(aaroey): remove the ResourceMgr singleton. - auto trt_rm = TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibration"); + auto res_mgr = ctx->resource_manager(); TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->LookupOrCreate( - funcdef_name_, "Calibrator", &calib_res, - {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { - return this->AllocateCalibrationResources(ctx, cr); - }}); - if (!status.ok()) { - ctx->SetStatus(status); - return; + OP_REQUIRES_OK( + ctx, + res_mgr->LookupOrCreate( + "TF_TRT_Calibration", name(), + reinterpret_cast(&calib_res), + {[ctx, this](SerializableResourceBase** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }})); + tensorflow::core::ScopedUnref calib_sc(calib_res); + // TODO(aaroey): here we also add the resource to the ResourceMgr singleton. + // This is needed before we migrate all uses of calib_graph_to_infer_graph() + // to the new calibration workflow. After that we'll remove this block. + { + auto deprecated_rm = + TRTResourceManager::instance()->getManager("TRTCalibration"); + TRTCalibrationResource* copied_resource = nullptr; + // Check whether the resource exists, and create it if not. + if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource) + .ok()) { + // Do nothing if the resource exists. + copied_resource->Unref(); + } else { + copied_resource = calib_res; + // Increase the refcount by 1 then transfer the ownership of that refcount + // to the ResourceMgr singleton. + copied_resource->Ref(); + OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator", + copied_resource)); + } } int num_inputs = ctx->num_inputs(); // Pass input data to calibrator @@ -219,7 +240,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return; } // Check the allocated buffer is sufficient for input - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + const auto device_tensor = + calib_res->device_tensors_.at(i).AccessTensor(ctx); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); input_data.emplace(StrCat(kInputPHName, i), data_address); } @@ -236,32 +258,34 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ExecuteNativeSegment(ctx, helper); } -int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { - int num_batch = ctx->input(0).shape().dim_size(0); - int smallest_engine = 0; - for (const auto i : cached_engine_batches_) { - if (i >= num_batch) { - smallest_engine = i; - break; - } - } - // TODO(sami): Need an LRU here - if (smallest_engine == 0) { - if (max_cached_engines_ > cached_engine_batches_.size()) { - smallest_engine = num_batch; - cached_engine_batches_.push_back(num_batch); - VLOG(1) << "Running with batch size " << num_batch; - } else { - string msg = - StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, - ", current entries="); - for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); - StrAppend(&msg, " requested batch=", num_batch); - LOG(WARNING) << msg; - return -1; +bool TRTEngineOp::GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes) { + const int batch_size = actual_input_shapes[0].dim_size(0); + int smallest_batch_size = -1; + // Output shape will always be the same as the input but we will overwrite the + // batch size. + *engine_input_shapes = actual_input_shapes; + for (const int cached_batch_size : cached_engine_batches_) { + // Check if compatible: batch <= cached batch. + // + // TODO(laigd): here it only compare the first dim a.k.a the batch size, + // we'll need to to support non-batch dimensions as well. This will be done + // as part of the offline conversion implementation. + if (batch_size <= cached_batch_size) { + // First case: first compatible engine found + // Second case: smaller batch size engine found + if ((smallest_batch_size == -1) || + (cached_batch_size < smallest_batch_size)) { + smallest_batch_size = cached_batch_size; + // Overwrite batch size for output + for (int i = 0; i < engine_input_shapes->size(); i++) { + (*engine_input_shapes)[i].set_dim(0, smallest_batch_size); + } + } } } - return smallest_engine; + return (smallest_batch_size != -1); } void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, @@ -272,25 +296,20 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, ExecuteCalibration(ctx, helper); return; } - const int smallest_engine = GetEngineBatch(ctx); - if (smallest_engine < 0) { - LOG(WARNING) << "Failed to get engine batch, running native segment for " - << name(); - ExecuteNativeSegment(ctx, helper); - return; + // Get shapes of inputs to engine. + std::vector input_shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + input_shapes.emplace_back(ctx->input(i).shape()); } - - const int num_batch = ctx->input(0).shape().dim_size(0); - auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); - auto& trt_engine_ptr = engine_ctx_pair.first; - if (!trt_engine_ptr) { - LOG(WARNING) << "Engine retrieval for batch size " << num_batch + EngineContext* engine_context = GetEngine(input_shapes, ctx); + if (!engine_context->cuda_engine) { + LOG(WARNING) << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_shapes) << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } - const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), - engine_ctx_pair.second.get()); + const bool retry = ExecuteTrtEngine(ctx, engine_context); if (retry) { LOG(WARNING) << "Failed to execute engine, " << "retrying with native segment for " << name(); @@ -299,18 +318,19 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } } -bool TRTEngineOp::ExecuteTrtEngine( - OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr) { +bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, + EngineContext* engine_context) { VLOG(1) << "Executing TRT engine: " << name(); + auto& cuda_engine = engine_context->cuda_engine; const bool kRetry = true; + // All inputs must have the same batch size, so just get it from the first + // input. + const int num_batch = ctx->input(0).shape().dim_size(0); const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(kInputPHName, i); - const int binding_index = - trt_engine_ptr->getBindingIndex(input_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(input_name.c_str()); if (binding_index == -1) { LOG(ERROR) << "Input node not found, at " << input_name; return kRetry; @@ -323,7 +343,7 @@ bool TRTEngineOp::ExecuteTrtEngine( << " vs " << input_shape.dim_size(0); return kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); @@ -346,13 +366,12 @@ bool TRTEngineOp::ExecuteTrtEngine( for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const int binding_index = - trt_engine_ptr->getBindingIndex(output_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr->getBindingDimensions(binding_index); + auto dims = cuda_engine->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; @@ -374,7 +393,7 @@ bool TRTEngineOp::ExecuteTrtEngine( // TODO(aaroey): ideally we should retry, fix this. return !kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = @@ -402,9 +421,12 @@ bool TRTEngineOp::ExecuteTrtEngine( ->implementation() ->GpuStreamMemberHack())); + // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex + // for it. + tensorflow::mutex_lock lock(engine_context->mu); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, - nullptr); + auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0], + *stream, nullptr); if (!ret) { LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; @@ -414,50 +436,45 @@ bool TRTEngineOp::ExecuteTrtEngine( return !kRetry; } -TRTEngineOp::~TRTEngineOp() { - // We need to manually destroy the engine and execution context before - // the allocator is destructed. - for (auto& eng : engine_map_) { - eng.second.first.reset(); - eng.second.second.reset(); +EngineContext* TRTEngineOp::GetEngine( + const std::vector& input_shapes, OpKernelContext* ctx) { + static EngineContext empty_context; + tensorflow::mutex_lock lock(engine_mutex_); + // TODO(tmorris): using first input to get batch size - is this reliable? + const int batch_size = input_shapes[0].dim_size(0); + + // Get engine cache + TRTEngineCacheResource* cache_res = nullptr; + auto status = ctx->resource_manager()->LookupOrCreate( + "TRTEngineCache", funcdef_name_, &cache_res, + {[this, ctx](TRTEngineCacheResource** cr) -> tensorflow::Status { + *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); + return Status::OK(); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return &empty_context; } - allocator_.reset(); -} - -nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { - if (allocator_) return allocator_.get(); - auto device = ctx->device(); - auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!alloc) { - LOG(ERROR) << "Can't find device allocator for gpu device " - << device->name(); - return nullptr; + tensorflow::core::ScopedUnref sc(cache_res); + auto& cache = cache_res->cache_; + auto allocator = cache_res->allocator_.get(); + if (allocator == nullptr) { + return &empty_context; } - allocator_.reset(new TRTDeviceAllocator(alloc)); - return allocator_.get(); -} - -TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, - OpKernelContext* ctx) { - static EngineCtxPair null_pair = { - TrtUniquePtrType(nullptr), - TrtUniquePtrType(nullptr)}; - // TODO(sami): This method needs to be re-written to use resource manager and - // with LRU mechanism option. - tensorflow::mutex_lock lock(engine_mutex_); + // Handle the static engine case. For static engines, the cache will have a + // single element containing the only engine. if (static_engine_) { - if (engine_map_.size()) { - if (engine_map_.begin()->first >= batch_size) { - return engine_map_.begin()->second; + if (cache.size()) { + // Batch size of engine must be >= the input batch size + // TODO(tmorris): use match compatible function? + if (cache.begin()->first[0].dim_size(0) >= batch_size) { + return cache.begin()->second.get(); } - return null_pair; + return &empty_context; } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); - auto allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } infer->setGpuAllocator(allocator); TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), @@ -465,62 +482,87 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); - engine_map_[max_batch_size] = { - std::move(static_engine), - TrtUniquePtrType( - raw_static_engine->createExecutionContext())}; + // Static engine will have max_batch_size for batch size so that all inputs + // will map to this single engine. + std::vector engine_input_shapes(input_shapes); + for (int i = 0; i < engine_input_shapes.size(); i++) { + // TODO(tmorris): will all inputs have batch size as first dimension?? + engine_input_shapes[i].set_dim(0, max_batch_size); + } + // TODO(laigd): here we assume engine_input_shapes matches the actual input + // shapes of the engine, we should verify that. + cache.emplace(engine_input_shapes, + absl::make_unique( + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext()))); // Runtime is safe to delete after engine creation serialized_segment_.clear(); if (max_batch_size < batch_size) { - return null_pair; + return &empty_context; } - return engine_map_.at(max_batch_size); + return cache.at(engine_input_shapes).get(); } // static_engine_ // Handle the dynamic engine case. - auto engine_it = engine_map_.find(batch_size); - if (engine_it == engine_map_.end() && - engine_map_.size() < (size_t)max_cached_engines_) { - nvinfer1::IGpuAllocator* allocator = nullptr; - allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } - std::vector shapes; - for (int i = 0; i < ctx->num_inputs(); ++i) { - shapes.emplace_back(ctx->input(i).shape()); + // See if there is a compatible engine cached. The batch size should be <= the + // cached batch size. + std::vector engine_input_shapes; + const bool matched_successfully = + GetCompatibleCachedEngine(input_shapes, &engine_input_shapes); + // If matched, use that engine. Otherwise, we will look in cache for that + // exact shape and possibly create a new engine if it is not in cache. + if (!matched_successfully) { + engine_input_shapes = input_shapes; + if (!cached_engine_batches_.empty()) { + // If user has explicitly defined cached_engine_batches, we should + // warn them that their input was non-compatible (batch size too high) + LOG(WARNING) << "No compatible cached engine was found for batch size: " + << batch_size << ". A new engine will be created."; + cached_engine_batches_.push_back(batch_size); } + } + + if (!cache.count(engine_input_shapes)) { TrtUniquePtrType engine; bool convert_successfully = false; LOG(INFO) << "Building a new TensorRT engine for " << name() - << " with batch size " << batch_size; + << " input shapes: " + << TensorShapeUtils::ShapeListString(engine_input_shapes); + // Convert to partial shapes + std::vector partial_shapes; + for (int i = 0; i < engine_input_shapes.size(); i++) { + partial_shapes.emplace_back(engine_input_shapes[i]); + } // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. auto status = convert::ConvertGraphDefToEngine( - segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, use_calibration_, - &convert_successfully); + segment_graph_, precision_mode_, batch_size, workspace_size_, + partial_shapes, &logger, allocator, calibrator_.get(), &engine, + use_calibration_, &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built // successfully, probably due to internal issues. In this case we don't // retry in the future. - engine_map_[batch_size] = {nullptr, nullptr}; + cache.emplace(engine_input_shapes, absl::make_unique()); } LOG(WARNING) << "Engine creation for batch size " << batch_size << " failed " << status; - return null_pair; + return &empty_context; } VLOG(1) << "Conversion is done"; TrtUniquePtrType exec_context( engine->createExecutionContext()); - engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + cache.emplace(engine_input_shapes, + absl::make_unique(std::move(engine), + std::move(exec_context))); } - return engine_map_.at(batch_size); + return cache.at(engine_input_shapes).get(); } tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, TRTCalibrationResource** cr) { + OpKernelContext* ctx, SerializableResourceBase** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. @@ -536,7 +578,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( const int batch_size = ctx->input(0).dim_size(0); const int num_inputs = ctx->num_inputs(); std::vector shapes; - dev_tensors_.resize(num_inputs); + cres->device_tensors_.resize(num_inputs); VLOG(1) << " Constructing calibrator"; for (int i = 0; i < num_inputs; i++) { // allocate workspace on device for inputs @@ -544,19 +586,19 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( shapes.emplace_back(t.shape()); Tensor* device_tensor; TF_RETURN_IF_ERROR(ctx->allocate_persistent( - t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor)); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); void* device_address = GetTensorAddress(device_tensor); if (device_address == nullptr) { return tensorflow::errors::InvalidArgument( "Unsupported data type encountered in input ", i); } - device_buffers_.emplace( + cres->device_buffers_.emplace( StrCat(kInputPHName, i), std::pair(device_address, device_tensor->TotalBytes())); } cres->calibrator_.reset( - new TRTInt8Calibrator(device_buffers_, batch_size, name())); + new TRTInt8Calibrator(cres->device_buffers_, batch_size, name())); const string label(name()); auto segment_graph = &segment_graph_; const int platform_gpu_id = diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h similarity index 67% rename from tensorflow/contrib/tensorrt/kernels/trt_engine_op.h rename to tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h index b545f497f32d5a1a6960b748467ca189b7debf6c..64f8c97a74092ac075de9cc7993283e3ce1e27cf 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h @@ -13,20 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_TRT_ENGINE_OP_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_TRT_ENGINE_OP_H_ #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -36,7 +39,6 @@ limitations under the License. namespace tensorflow { namespace tensorrt { struct TRTInt8Calibrator; -class TRTCalibrationResource; class AsyncHelper; // TODO(Sami): Remove this file? @@ -48,9 +50,10 @@ class TRTEngineOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - ~TRTEngineOp(); private: + // TODO(samikama): context should go to a resource manager! + // Execute calibration void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); @@ -62,33 +65,25 @@ class TRTEngineOp : public AsyncOpKernel { // Execute the tensorrt engine. Returns whether we need to retry by running // the native segment. - bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr); + bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); // Allocate necessary resources for calibration Status AllocateCalibrationResources(OpKernelContext* ctx, - TRTCalibrationResource** cr); - - // TODO(samikama): context should go to a resource manager! - typedef std::pair, - TrtUniquePtrType> - EngineCtxPair; - EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); + SerializableResourceBase** cr); - // Return engine batch closest to input batch. - int GetEngineBatch(OpKernelContext* ctx); + // Get engine for the input shape + EngineContext* GetEngine(const std::vector& input_shapes, + OpKernelContext* ctx); - nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); + // Return engine batch in cached_engne_batch_sizes_ which is closest to input + // batch. + bool GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes); - // map to keep engines and their execution context for given batch size. - std::unordered_map engine_map_; std::vector input_nodes_; std::vector output_nodes_; - // keep device allocator for TRT. - std::unique_ptr allocator_; - // serialized protobuf segment or trt engine depending on static_engine_ flag. string serialized_segment_; @@ -98,12 +93,6 @@ class TRTEngineOp : public AsyncOpKernel { // GraphDef representation of the segment. GraphDef segment_graph_; - // Lookup table for temporary staging areas of input tensors for calibration. - std::unordered_map> device_buffers_; - - // Temporary staging areas for calibration inputs. - std::vector dev_tensors_; - // Engine Precision mode. int precision_mode_; @@ -114,10 +103,6 @@ class TRTEngineOp : public AsyncOpKernel { // Whether to calibrate INT8 engine. bool calibration_mode_; - // Whether non-batch ranks of the inputs are assumed to be fixed or not for - // engine construction. - bool fixed_input_size_; - // Batches of the cached engines std::vector cached_engine_batches_; @@ -142,4 +127,4 @@ class TRTEngineOp : public AsyncOpKernel { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59da73f5efc8eedc20c35cf35cb1eae6cda136c9 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +REGISTER_OP("GetSerializedResourceOp") + .Input("container: string") + .Input("resource_name: string") + .Output("serialized_resource: string") + .SetShapeFn(shape_inference::ScalarShape) + .SetIsStateful() + .Doc(R"doc( +Gets a resource from a container managed by the resource manager and returns +its serialized representation. +)doc"); + +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc similarity index 80% rename from tensorflow/contrib/tensorrt/ops/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index 92405906eb76b043bc08b68e25e16ab40197dddf..b84d2fe0b8cef3475f2a7d0f5383d5e11cde099a 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -28,16 +28,22 @@ namespace shape_inference { extern Status TRTEngineOpShapeInference(InferenceContext* c); } +// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the +// list below, this will break backward compatibility! +// +// TODO(laigd): consider making this op stateful. The only problem is it uses TF +// function which has to be stateless, but we can use function library as the +// key to cache the instantiated functions for different executor subgraphs. REGISTER_OP("TRTEngineOp") .Attr("serialized_segment: string") .Attr("input_shapes: list(shape)") .Attr("output_shapes: list(shape)") .Attr("segment_funcdef_name: string") - .Attr("InT: list({int8,float16,float32})") - .Attr("OutT: list({int8,float16,float32})") + .Attr("InT: list({int8,float16,float32,int32})") + .Attr("OutT: list({int8,float16,float32,int32})") .Attr("static_engine: bool = true") .Attr("fixed_input_size: bool = true") - .Attr("cached_engine_batches: list(int) = []") + .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc index 062f86e8bb4dc753925e4e2baf0bc80a5312a94f..a4341c530fffca88c82813cc2ace2c0ae1df5345 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" + #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h similarity index 92% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index 754920b60ca7439513a91ad0354833a2482b29c1..f495d857037c79a1783f8eb232fb57c20e229169 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ #include #include @@ -71,4 +71,4 @@ class PluginTensorRT : public nvinfer1::IPlugin { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc index cccc91226265ed139fb8db0b71c40b868f729562..871fb1210bd495dc3f5e8153bb6c3a361bf569f5 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h similarity index 91% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h index bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8..9aa99a40b80de92a4d9b9ad36e88e693b8aa42dc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -99,4 +99,4 @@ class TrtPluginRegistrar { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc index 129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e..7d9c465c22beed0e252cbc26d6c533a0789d4f49 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc similarity index 94% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc index a8f60886c03c174a612e7a135b6eb7bb7cb9997a..f3d6b4ff476139693a5251ddf58a3200d8af8efc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include #if GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h similarity index 82% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h index 274ce42fec9283c643004d45fba461879fc5f2dc..e5eff15c19694093c7a5ea933a41375e8e01c8b9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -43,4 +43,4 @@ string ExtractOpName(const void* serial_data, size_t serial_length, #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py similarity index 84% rename from tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py rename to tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 31a313182be9a2fca7457a539670dbc911ccabb1..86bfabf99e08a8e447a28504c72eebca4d3a582c 100644 --- a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -22,13 +22,13 @@ import platform if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import * - from tensorflow.contrib.util import loader + from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _trt_engine_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_trt_engine_op.so")) + _trt_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile("_trt_ops.so")) else: raise RuntimeError("Windows platforms are not supported") diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc similarity index 97% rename from tensorflow/contrib/tensorrt/segment/segment.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment.cc index 084a96e0fa5c97edc58adf2590ed94e5ef0e4d85..4a8a4ac7589a4b68b129e8e88ee999e8a2495728 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include #include #include #include -#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,8 +33,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace segment { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -673,10 +674,11 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = - itr.second; + const string& segment_root = itr.first; + // Return format does not require set comparator. + std::set segment_nodes(itr.second.begin(), itr.second.end()); if (VLOG_IS_ON(1)) { - string s = "parent=" + itr.first + ":"; + string s = "parent=" + segment_root + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } @@ -689,12 +691,10 @@ tensorflow::Status SegmentGraph( } // TODO(sami): Make segmenter placement aware once trtscopes are in place - std::set segment_node_names; - for (auto node : itr.second) segment_node_names.insert(node->name()); - const auto& dev_itr = device_maps.find(itr.first); + const auto& dev_itr = device_maps.find(segment_root); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); - segments->emplace_back(std::make_pair(segment_node_names, string())); + segments->emplace_back(std::make_pair(segment_nodes, string())); } else if (dev_itr->second.size() > 1) { string s("Segment "); StrAppend(&s, segments->size(), " has multiple devices attached: "); @@ -703,10 +703,10 @@ tensorflow::Status SegmentGraph( } LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } else { segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h similarity index 83% rename from tensorflow/contrib/tensorrt/segment/segment.h rename to tensorflow/compiler/tf2tensorrt/segment/segment.h index b9693aad1b764515459db6833b05221ea5b3a2d1..9a0ccc9aef475edfb0ffb83a2be21d4d4ca0e028 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ #include #include @@ -29,10 +29,10 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node names and a device name -// in the segment. -// TODO(aaroey): use node pointer instead of node name. -using SegmentNodesVector = std::vector, string>>; +// Vector of segments, each entry contains a set of node pointers and a device +// name in the segment. +using SegmentNodesVector = + std::vector, string>>; struct SegmentOptions { // Segment must contain at least this many nodes. @@ -60,4 +60,4 @@ tensorflow::Status SegmentGraph( } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc similarity index 97% rename from tensorflow/contrib/tensorrt/segment/segment_test.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 4805ef9c61a7784a1c08cf5eaf504691bc9dbedc..58512d3b09d7c6f523710bc09843c628a5838b53 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -75,7 +75,10 @@ class SegmentTest : public ::testing::Test { const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { - const auto& segment_node_names = segments[i].first; + std::set segment_node_names; + for (const Node* node : segments[i].first) { + segment_node_names.insert(node->name()); + } const auto& expected = expected_segments[i]; for (const auto& name : expected) { EXPECT_TRUE(segment_node_names.count(name)) diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h similarity index 92% rename from tensorflow/contrib/tensorrt/segment/union_find.h rename to tensorflow/compiler/tf2tensorrt/segment/union_find.h index 1c64ebbb0ae532a4776ab8963515d19fd3b23b4c..6458ae692fd7c922b5fc3bea2e55b613447dbde0 100644 --- a/tensorflow/contrib/tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ namespace tensorflow { namespace tensorrt { @@ -76,4 +76,4 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc similarity index 100% rename from tensorflow/contrib/tensorrt/tensorrt_test.cc rename to tensorflow/compiler/tf2tensorrt/tensorrt_test.cc diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc similarity index 97% rename from tensorflow/contrib/tensorrt/test/utils.cc rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 276308b3a0a6ce864969afb0179c6a3f00d6b70b..3bcca99afbff8b84d2dd628ae9211ee94e86af2a 100644 --- a/tensorflow/contrib/tensorrt/test/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" #include #include diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h similarity index 89% rename from tensorflow/contrib/tensorrt/test/utils.h rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.h index 4bb4120206cfaae70107e55d1818e3af2f02717a..bcd628b62f0320f7ce9dfe6240316d876f1d5a20 100644 --- a/tensorflow/contrib/tensorrt/test/utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -41,4 +41,4 @@ string GetTestValue(const string& label); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index 7a2e93414aed56525eaeac876cdac20404bcf6ab..1636cdc30c4df157ed124b160449af645f917252 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h similarity index 93% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index f857a9de055ee7668f0bf9bc97e030354505081b..59ffb42bad348c78cde32035aff8c7081528b3a6 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ #include @@ -81,4 +81,4 @@ class TRTDeviceAllocator : public TRTBaseAllocator { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc index beb1284208e4c10ffe1d36ef411cf08f11dbcb78..e457c64928e5df84c7e2726ba3621420f013dbc9 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5..bf111d3a2ee2fbec9151d12bbb6ff7181761c2aa 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include #include diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h similarity index 93% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 65466c9741989fda5f82fc27d813d026f35fe386..10587e99624acfb97730bbbd9dfbcde020ffc669 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ #include #include @@ -96,4 +96,4 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { #endif #endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc similarity index 96% rename from tensorflow/contrib/tensorrt/log/trt_logger.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index dda0dc9e712eb726800abfb6084f4f708d04825b..c48bd6bf7747d1646c4e450b780822728e8573f1 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h similarity index 86% rename from tensorflow/contrib/tensorrt/log/trt_logger.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 96ccacb791e40143c5c4d9d691bb353702f9a28b..22f4de970a80765b0e1e7e8816134d83aaec7c73 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ #include "tensorflow/core/platform/types.h" @@ -41,4 +41,4 @@ class Logger : public nvinfer1::ILogger { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..09c47b36b0ad8074e749342e7d08f139da7ea1f4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -0,0 +1,192 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +template +class LRUCache { + public: + typedef Value value_type; + typedef Key key_type; + typedef HashFunction hasher; + typedef typename std::unordered_map map_type; + typedef typename map_type::iterator iterator; + typedef typename map_type::const_iterator const_iterator; + + LRUCache() : capacity_(0) {} + explicit LRUCache(size_t capacity) : capacity_(capacity) {} + + size_t capacity() const { return capacity_; } + + void reserve(size_t capacity) { + capacity_ = capacity; + DiscardOld(); + } + + size_t size() const { return objects_.size(); } + + size_t count(const key_type& key) const { return objects_.count(key); } + + value_type& at(const key_type& key) { return Touch(key); } + + const_iterator begin() const { return objects_.begin(); } + const_iterator end() const { return objects_.end(); } + + iterator begin() { return objects_.begin(); } + iterator end() { return objects_.end(); } + + template + std::pair emplace(Args&&... args) { + DiscardOld(1); + std::pair result = + objects_.emplace(std::forward(args)...); + key_type key = result.first->first; + if (result.second) { + keys_.push_front(key); + } else { + TouchNoCheck(key); // The key must exist in this case. + } + return result; + } + + private: + std::unordered_map objects_; + std::list keys_; + size_t capacity_; + value_type not_found_value_; + + value_type& Touch(const key_type& key) { + // Check that the key exists, and let it return std::out_of_range error if + // not. + value_type& value = objects_.at(key); + TouchNoCheck(key); + return value; + } + + void TouchNoCheck(const key_type& key) { + auto rank = std::find(keys_.begin(), keys_.end(), key); + if (rank != keys_.begin()) { + keys_.erase(rank); + keys_.push_front(key); + } + } + + // Creates n free positions in cache + tensorflow::Status DiscardOld(size_t n = 0) { + if (n > capacity_) { + return tensorflow::errors::Internal( + "Insufficient capacity in cache (capacity = ", capacity_, + ", requested ", n, ")"); + } + while (objects_.size() > (capacity_ - n)) { + key_type discard_key = keys_.back(); + keys_.pop_back(); + objects_.erase(discard_key); + } + return tensorflow::Status::OK(); + } +}; + +// Define a hash function for vector because it is used as the key +// for the engine cache. +struct VectorTensorShapeHasher { + std::size_t operator()( + const std::vector& key) const { + return std::hash()(TensorShapeUtils::ShapeListString(key)); + } +}; + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +struct EngineContext { + EngineContext() {} // Creates an empty context. + EngineContext( + TrtUniquePtrType&& input_cuda_engine, + TrtUniquePtrType&& input_execution_context) + : cuda_engine(std::move(input_cuda_engine)), + execution_context(std::move(input_execution_context)) {} + + mutex mu; + TrtUniquePtrType cuda_engine; + TrtUniquePtrType execution_context + GUARDED_BY(mu); +}; + +class TRTEngineCacheResource : public tensorflow::ResourceBase { + public: + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity) + : cache_(capacity) { + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + allocator_ = nullptr; + } else { + allocator_.reset(new TRTDeviceAllocator(alloc)); + } + } + + string DebugString() const override { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << "TRTEngineCacheResource: "; + oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", "; + oss << "LRUCache = " << hex << &cache_ << dec << endl; + oss << "Containing " << cache_.size() << " entries: " << endl; + for (const auto& item : cache_) { + oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex + << "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", " + << "IExecutionContext: " << item.second.get()->execution_context.get() + << dec << endl; + } + return oss.str(); + } + + // Keep device allocator for TRT. + std::unique_ptr allocator_; + + // Declare cache after allocator so that it is destroyed before allocator is. + LRUCache, std::unique_ptr, + VectorTensorShapeHasher> + cache_; +}; + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0aa5eb8f7d4ad062c2d8622fa5aa55f823f80dd5 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tensorrt { + +TEST(LRUCacheTest, Basic) { + LRUCache> cache; + cache.reserve(2); + // Insert 10 + cache.emplace(10, 100); + EXPECT_EQ(cache.size(), 1); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.count(100), 0); + // Insert 20 + cache.emplace(20, 200); + EXPECT_EQ(cache.size(), 2); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.at(20), 200); + EXPECT_EQ(cache.count(100), 0); + EXPECT_EQ(cache.count(200), 0); + // Insert 30, Evicting 10 + cache.emplace(30, 300); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 1); + // Touch 20 + cache.at(20); + // Insert 40, Evicting 30 + cache.emplace(40, 400); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 0); + EXPECT_EQ(cache.count(40), 1); +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc similarity index 96% rename from tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc index 9c3698e5d1cc5d6d8d31a8fcaf03d103f1e1915d..0a72a88bc740101bcbadb40bfe106a5b8d284bbf 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h similarity index 87% rename from tensorflow/contrib/tensorrt/resources/trt_resource_manager.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h index 19f39e6d3db1571573fb290dd2c30fd43ea604ef..03879ffff2fa724b05cb1919753e4aaa99e2e702 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ #include #include @@ -42,4 +42,4 @@ class TRTResourceManager { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc new file mode 100644 index 0000000000000000000000000000000000000000..37f7fe99fbb2b9e121953fc0de211db1bbf34b7a --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +TRTCalibrationResource::~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); +} + +string TRTCalibrationResource::DebugString() const { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; + return oss.str(); +} + +Status TRTCalibrationResource::SerializeToString(string* serialized) { + calibrator_->waitAndSetDone(); + thr_->join(); + *serialized = calibrator_->getCalibrationTableAsString(); + if (!serialized->size()) { + return tensorflow::errors::Unknown("Calibration table is empty."); + } + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8d4b3b738df09b0c2ea82dcc06e9b23a708385 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class SerializableResourceBase : public tensorflow::ResourceBase { + public: + virtual Status SerializeToString(string* serialized) = 0; +}; + +class TRTCalibrationResource : public SerializableResourceBase { + public: + ~TRTCalibrationResource() override; + + string DebugString() const override; + + Status SerializeToString(string* serialized) override; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector device_tensors_; + + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; + // TODO(sami): Use threadpool threads! + std::unique_ptr thr_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index d8123e956fac04912b4fed5bf75cc9cb55c5baf9..92ba474fbcd085e3e33ceea4395cca4034969bd9 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -204,6 +204,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", @@ -224,6 +225,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], alwayslink = 1, ) @@ -669,6 +671,7 @@ cc_library( name = "side_effect_util", srcs = ["side_effect_util.cc"], hdrs = ["side_effect_util.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h index dfc1e8b8aebcf3142e9f61f60171c6b58634c71d..78970fb39bae7067c7668baa2aec65732b5b2352 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -104,7 +104,7 @@ class BufferInfo { private: BufferInfo() = default; - enum class Kind : unsigned { + enum class Kind : uint64 { kConstant, kTempBuffer, kEntryParameter, diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index efb75749722893100494e089c0beb96944e9f1d4..5e4699bbb6218089d2e76a36c7351bf7fbd23264 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -88,6 +89,9 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, case XlaExpression::Kind::kResource: return errors::Unimplemented( "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kTensorList: + return errors::Unimplemented( + "TensorList as function argument is not yet implemented."); case XlaExpression::Kind::kInvalid: return errors::InvalidArgument("Invalid function argument"); } @@ -191,6 +195,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, // into the functions. XlaOpKernelContext xla_op_context(op_context); + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; @@ -219,8 +226,12 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + bool add_token_input_output = + HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; + compile_options.add_token_input_output = add_token_input_output; XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler->CompileFunction(compile_options, func, arguments, &result)); @@ -234,9 +245,19 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } handles.push_back(expressions[i]->handle()); } - - XlaContext& context = XlaContext::Get(op_context); - auto* b = context.builder(); + if (add_token_input_output) { + std::vector token_input_nodes; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + std::vector token_inputs; + for (const string& node_name : token_input_nodes) { + auto token_or = compiler->GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ConsumeValueOrDie()); + } + xla::XlaOp token_input = xla::AfterAll(b, token_inputs); + handles.push_back(token_input); + } auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so @@ -251,6 +272,10 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, ++computation_output; } } + if (add_token_input_output) { + TF_RETURN_IF_ERROR(compiler->SetNodeToken( + n->name(), xla::GetTupleElement(output_handle, computation_output))); + } return b->first_error(); } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 47209d285f1a077fd80f779a406e6980892f1646..52d2901e73d16f71ecbf633ede0d2cf553b6e521 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -144,13 +144,22 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 46e5d68c78fd9ff26a88dc2a1484c3a67b76f4f3..6b675fa8a94e0bc932baaa359565cbc8e4614ee5 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(crops.shape()) == 2 && + crops.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), errors::InvalidArgument("crops should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 5e9280c1fe692037b0a842a92ef5a8c28b854a54..ad6b334326a470442c8c0d79b725345d4165be10 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -165,12 +167,8 @@ XLA_MAKE_BINARY( xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return xla::Mul(x, x); -} - XLA_MAKE_BINARY(SquaredDifference, - Square(b, xla::Sub(lhs, rhs, extend_dimensions))); + xla::Square(xla::Sub(lhs, rhs, extend_dimensions))); XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); @@ -195,8 +193,8 @@ XLA_MAKE_BINARY(SoftplusGrad, // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, xla::Div(lhs, - Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Abs(rhs))))); + xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); XLA_MAKE_BINARY(TanhGrad, xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), @@ -204,6 +202,8 @@ XLA_MAKE_BINARY(TanhGrad, XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); + #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 8cc2479dd555380da7500abe6b2aca380110333b..ca2152d6c103e05c06809d85d9529720ff112217 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -19,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -31,6 +33,7 @@ class CastOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,6 +51,36 @@ class CastOp : public XlaOpKernel { // imaginary part. output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { + if (use_truncation_) { + OP_REQUIRES( + ctx, + xla::primitive_util::IsFloatingPointType(src_type_) && + xla::primitive_util::IsFloatingPointType(dst_type_), + errors::Unimplemented("Truncate attribute is only " + "implemented for floating point datatypes.")); + int mantissa_difference = + xla::primitive_util::SignificandWidth(src_type_) - + xla::primitive_util::SignificandWidth(dst_type_); + OP_REQUIRES(ctx, mantissa_difference > 0, + errors::Unimplemented( + "Truncate attribute is only implemented in cases where " + "dst datatype " + "has fewer mantissa bits than the src datatype")); + int src_bitwidth = xla::primitive_util::BitWidth(src_type_); + + // Bitcast to same-width integer, mask off the LSBs, bitcast back to the + // source datatype. + int64 mask = ~((1L << mantissa_difference) - 1); + xla::PrimitiveType same_width_int = + xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth); + OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID, + errors::Unimplemented("Unexpected type bitwidth")); + input = xla::BitcastConvertType( + xla::And( + xla::BitcastConvertType(input, same_width_int), + ::tensorflow::IntegerLiteral(builder, same_width_int, mask)), + src_type_); + } output = xla::ConvertElementType(input, dst_type_); } @@ -57,6 +90,7 @@ class CastOp : public XlaOpKernel { protected: DataType src_dtype_, dst_dtype_; xla::PrimitiveType src_type_, dst_type_; + bool use_truncation_; TF_DISALLOW_COPY_AND_ASSIGN(CastOp); }; @@ -79,8 +113,8 @@ class BitcastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else { - // The only complex type in XLA is C64, so error out if the bitcast has a - // complex source or destination type and the bitcast is not trivial. + // Error out if the bitcast has a complex source or destination type and + // the bitcast is not trivial. OP_REQUIRES(ctx, !xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_), diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7199b9b6feb36dd45ef51f4c38463bc715fcc38a..c2b4c28d1566f5429c5d8109db94af0c3762b131 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -99,8 +99,8 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); if (num_samples == 1) { argmax = xla::Reshape(argmax, {batch_size, 1}); } diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index cd7c7f4a82df7a65829787efcb1fd2f77870e945..91e4d9cea7cbf6075e30250587044174c4b8e7f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index dff8af800229b9605bb93e0498bc5e5cf012f244..ff6c54e47c62f0555ef045e25051f6ec5a3c1d39 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX128: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex128(proto_.dcomplex_val(0), + proto_.dcomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index b0bc7640307149459a29e6b0b2e8e8132e4141c9..5f99b24e221ba6c926032ef7a1b4bf1e92df7a68 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -26,13 +26,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" @@ -212,8 +212,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, data_format, - dims); + out_backprop_tensor_shape, dilations, strides, padding, + /*explicit_paddings=*/{}, data_format, dims); } } // anonymous namespace @@ -227,6 +227,11 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + // TODO(reedwm): Support explicit padding. + if (attrs.padding == EXPLICIT) { + return errors::Unimplemented( + "XLA does not yet support Conv2D with explicit padding."); + } string data_format; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); @@ -428,23 +433,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); - // The conversion logic below assumes that the data format is NHWC, so we also - // check that here. bool use_batch_group_count = - filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise && - attrs.data_format == FORMAT_NHWC; + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; std::vector> padding(attrs.num_spatial_dims); std::vector rhs_dilation(attrs.num_spatial_dims); std::vector window_strides(attrs.num_spatial_dims); std::vector ones(attrs.num_spatial_dims, 1); - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); @@ -478,7 +474,7 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. - // + const int64 padded_in_size = dims.spatial_dims[i].expanded_output_size + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index eafdba876ae9e2c38694f065cf83bb3725b8460e..52c3c2c4a903a8c51f6b511774bc0312d39df826 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -25,13 +25,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c..b119997cf39e210ed8e0ae730a08829e72b238b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 6df8b5367d2390e65995beb1583b225755e6ee9f..a623585aad3b1b8f1f096ca527e7694d74f1ba46 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -21,12 +21,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 41c31d0ed58fe9bc9bbde0bd58993c975f04fd60..6472045265e4d930a5da770a68f5c502192201ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -167,13 +167,13 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; - } OP_REQUIRES( - context, 0 <= axis && axis < params_dims, + context, -params_dims <= axis && axis < params_dims, errors::InvalidArgument("Expected axis in the range [", -params_dims, ", ", params_dims, "), but got ", axis)); + if (axis < 0) { + axis += params_dims; + } } DataType index_type = input_type(1); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 4f0f0fd9aefecc3d31f8bd9c8ca40ebb0860c82d..aa5637e2669555da17af8bb05ab08beeba6a89c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -80,7 +80,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.HumanString() << " initialized: " << arg.initialized; num_resource_args++; @@ -89,7 +89,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString(); + << " shape: " << arg.HumanString(); } } @@ -150,12 +150,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape then_input_shape = then_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + OP_REQUIRES(ctx, then_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape else_input_shape = else_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + OP_REQUIRES(ctx, else_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), @@ -248,7 +248,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 96ddd42e2ae04d454e4fb85628d139e17a543d2e..92b20fe0ba5611ca5314cd954026f7b71ea75f84 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -185,19 +187,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { factor_shape.DebugString())); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp factor = context->Input(1); - DataType type = context->input_type(0); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(reduce, type); - output = - xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + + auto output = xla::Div( + reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); + output = XlaHelpers::ConvertElementType(output, type); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); @@ -233,8 +236,10 @@ class AdjustSaturationOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp scale = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp scale = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -249,15 +254,17 @@ class AdjustSaturationOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, DT_FLOAT)); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -283,8 +290,10 @@ class AdjustHueOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp delta = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp delta = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -299,20 +308,22 @@ class AdjustHueOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - auto zero = XlaHelpers::Zero(b, type); - auto one = XlaHelpers::One(b, type); + auto zero = XlaHelpers::Zero(b, DT_FLOAT); + auto one = XlaHelpers::One(b, DT_FLOAT); auto& hue = hsv[0]; hue = xla::Rem(xla::Add(hsv[0], delta), one); hue = xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); @@ -351,24 +362,26 @@ struct SuppressBodyFn { auto num_outputs_so_far = values[1]; auto iou_mask = values[2]; auto included_iou = values[3]; - auto zero_r1 = xla::ConstantR1(builder, {0}); + auto zero = xla::ConstantR0(builder, 0); // Determine if current elem is active using a slice. - auto row_idx_r1 = xla::Reshape(row_idx, {1}); - auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + // TODO(b/118437727): The only reason we need an explicit vector is because + // some old GCCs can't deduce the right type for MakeConstSpan, and + // providing a single-value initializer list directly uses the wrong + // overload. Delete this once the deprecated overload is gone. + std::vector row_idx_vector = {row_idx}; + auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1}); active_elem = xla::Reshape(active_elem, {}); // Increment output count iff current elem is not suppressed. num_outputs_so_far = xla::Select( active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), num_outputs_so_far); // Slice out the row_idx. - auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); - auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); // Remove the diagonal from consideration. An elem cannot suppress // itself. - auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); row_iou = xla::DynamicUpdateSlice( row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - update_starts); + {zero, row_idx}); // Create a suppression by inverting polarity. row_iou = xla::Reshape(row_iou, {num_boxes}); auto supp_mask = xla::Not(row_iou); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 5a10c52ba8b6d4fab73f0dda67cbd52fd625e76b..b96d45316f626e678a64392a4315979eeeb6e83c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -72,10 +72,10 @@ namespace { // from in_size to out_size. struct ResizeConvolutionDims { // Size of the kernel to use. - std::vector kernel_size; + std::vector kernel_size; // k // Stride of the convolution to use. - std::vector stride; + std::vector stride; // S }; ResizeConvolutionDims ComputeResizeConvolutionParameters( absl::Span in_size, absl::Span out_size, @@ -117,8 +117,10 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // + dims.stride * (out_size - 1) int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, int64 stride) { - return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - - 1 - (kernel_size * (in_size - 1)); + int64 padding = (2 * kernel_size - 1) + (out_size - 1) * stride - + (kernel_size - 1) - 1 - (kernel_size * (in_size - 1)); + + return padding; } // Form a 2D convolution kernel like: @@ -132,7 +134,7 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; @@ -142,43 +144,64 @@ xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { return xla::ConstantR1(builder, kernel); } +// Unlike the bilinear kernel, which is triangular, the nearest neighbor +// kernel is a square. For example, a 1D kernel with n=3 would look like +// [0 1 1 1 0] +// and n=4 would look like +// [0 0 1 1 1 1 0]. +// Note that in the second case, the kernel is not symmetric and we default +// to the right (because an existing non TPU kernel +// for nearest neighbor resize already chose to default to the right, +// so we want to be consistent). +xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { + std::vector kernel(n * 2 - 1, 0.0f); + std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); + + return xla::ConstantR1(builder, kernel); +} + // Kernels with more than 16 spatial elements are considered intense and the -// kernel should applied to each dimension independently. +// kernel should be applied to each dimension independently. const int64 kMax2DKernelSize = 16; -xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels) { +xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + absl::Span kernel_size, + int64 channels, bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; + auto depthwise_kernel = xla::Broadcast( xla::Zero(builder, xla::F32), {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); return xla::Mul( - xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), + xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - Make1DKernel(builder, kernel_size[0]), + make_kernel_func(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } -xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels, int64 dim) { +xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + absl::Span kernel_size, + int64 channels, int64 dim, + bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; + auto depthwise_kernel = xla::Broadcast(xla::Zero(builder, xla::F32), {dim == 0 ? (2 * kernel_size[0] - 1) : 1, dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + return xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[dim]), /*broadcast_dimensions=*/{dim}); } -xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, - const xla::XlaOp& input, - const int num_spatial_dims, - std::vector in_size, - std::vector out_size, - const int64 channels, - const bool align_corners) { - // Picture for a 1x3 to 1x4 resize: +xla::XlaOp ResizeUsingDilationAndConvolution( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const int num_spatial_dims, std::vector in_size, + std::vector out_size, const int64 channels, const bool align_corners, + bool is_kernel_bilinear) { + // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: // 3 6 9 @@ -264,8 +287,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // Split convolutions into independent dimensions if they would be a very // large kernel. if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, /*padding=*/ @@ -275,8 +298,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -284,8 +307,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -306,13 +329,11 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, return output; } -xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, - const xla::XlaOp& grad, - const int num_spatial_dims, - std::vector in_size, - std::vector grad_size, - const int64 channels, - const bool align_corners) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( + xla::XlaBuilder* builder, const xla::XlaOp& grad, + const int num_spatial_dims, std::vector in_size, + std::vector grad_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -332,8 +353,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the @@ -355,14 +376,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + + // Broadcast the input kernel where the forward op expanded from a + // size == 1 dimension to a size > 1 dimension. This has the effect of + // summing the gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { kernel0 = xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), @@ -407,109 +428,139 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, return output; } -class ResizeBilinearOp : public XlaOpKernel { - public: - explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); +void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, + bool is_kernel_bilinear) { + xla::XlaBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + // First dimension always assumed to be batch + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + // Last/4th dimension always assumed to be num channels + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::XlaOp input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + in_size[i] = 1; + } + } + if (slice_input) { + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); + // Output is always type float. + input = xla::ConvertElementType(input, xla::F32); + + // Special Case: + // Instead of doing a ResizeUsingDilationAndConvolution directly, + // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the + // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). + // Instead of resizing directly we resize it iteratively. + // + // Since bilinear resize can be broken down as 2 sequential linear + // operations along different dimensions. + // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. + // + // This makes the convolutions kernels smaller and the operation faster. + xla::XlaOp output = input; + while (in_size != out_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1 && align_corners_) { + std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, next_out_size, channels, + align_corners_, is_kernel_bilinear); + input = output; + in_size = next_out_size; + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); + in_size = out_size; + } + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); + in_size = out_size; + } + } - TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_shape.DebugString())); - const int64 batch = input_shape.dim_size(0); - std::vector in_size = {input_shape.dim_size(1), - input_shape.dim_size(2)}; - const int64 channels = input_shape.dim_size(3); - OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, - errors::InvalidArgument("input size must be positive, got [", - in_size[0], ",", in_size[1], "]")); + ctx->SetOutput(0, output); +} - std::vector out_size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); - OP_REQUIRES(ctx, out_size.size() == 2, - errors::InvalidArgument("output size must be length 2, got ", - out_size.size())); - OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, - errors::InvalidArgument("output size must be positive, got [", - out_size[0], ",", out_size[1], "]")); +class ResizeNearestNeighborOp : public XlaOpKernel { + public: + explicit ResizeNearestNeighborOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeNearestNeighbor with align_corners=False " + "is not yet implemented")); + } - const int num_spatial_dims = 2; + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); + } - xla::XlaOp input = ctx->Input(0); + private: + bool align_corners_ = true; + bool is_kernel_bilinear_ = false; +}; - // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in - // dimension i. - bool slice_input = false; - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] > 1 && out_size[i] == 1) { - // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first - // entry before resizing. - slice_input = true; - in_size[i] = 1; - } - } - if (slice_input) { - input = - xla::Slice(input, {0, 0, 0, 0}, - {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); - } +REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"), + ResizeNearestNeighborOp); - // Output is always type float. - input = xla::ConvertElementType(input, xla::F32); - - // Special Case: - // Instead of doing a ResizeUsingDilationAndConvolution directly, - // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the - // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). - // Instead of resizing directly we resize it iteratively. - // - // Since bilinear resize can be broken down as 2 sequential linear - // operations along different dimensions. - // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. - // This does not work in the case of align_corners_=false because of special - // padding requirements that cause multiple resizes to be very different - // from a single resize. - // - // This makes the convolutions kernels smaller and the operation faster. - xla::XlaOp output = input; - while (in_size != out_size) { - if (in_size[0] != 1 && in_size[1] != 1) { - std::vector k = { - (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), - (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; - if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1 && align_corners_) { - std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, - (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, next_out_size, - channels, align_corners_); - input = output; - in_size = next_out_size; - } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, - channels, align_corners_); - in_size = out_size; - } - } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels, - align_corners_); - in_size = out_size; - } - } +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + } - ctx->SetOutput(0, output); + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); } private: - bool align_corners_; + bool align_corners_ = true; + bool is_kernel_bilinear_ = true; }; REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"), @@ -581,19 +632,19 @@ class ResizeBilinearGradOp : public XlaOpKernel { (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_); + align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + align_corners_, true); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + align_corners_, true); in_size = grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 843b6bb4e658af16fd753c1a20b35dd3d18df027..c1539f48d4f729510b2d930de91666a7c31f1ef0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -18,17 +18,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/index_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) @@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; if (is_min_) { - output = XlaHelpers::ArgMin(input, index_xla_type, axis); + output = xla::ArgMin(input, index_xla_type, axis); } else { - output = XlaHelpers::ArgMax(input, index_xla_type, axis); + output = xla::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index e2c05b648bb194b1b452c527ddb1a2c5995b1217..e4bbdef6480104a1051acfc647644deb65c80171 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -16,16 +16,16 @@ limitations under the License. // Native XLA implementations of indexing ops. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // shape isn't supported. if (!ctx->compiler()->options().allow_cpu_custom_calls || (input_dims != 1 && input_dims != 2)) { - xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis); ctx->SetOutput(0, output); return; } @@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { auto shape_status = b.GetShape(arg); OP_REQUIRES_OK(ctx, shape_status.status()); xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); - *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( - xla::ShapeUtil::Rank(arg_shape)); + *arg_shape.mutable_layout() = + xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); arg_shapes.push_back(std::move(arg_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6440770c29894c951f010f6c1deb929f4fe79bbf..f36e0025250b3a196b31755a1ddf6620c415b6a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -24,8 +24,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; class MatMulOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index f6b8534f4d7c537e5b708ee000e00cb92123584b..656f9b898f32dfc05215014f51c2bbaf07580836 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; - for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; - --dimno) { + for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); int64 lhs_padding = pad_literal.Get({dimno, 0}); int64 rhs_padding = pad_literal.Get({dimno, 1}); diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a9b519d8928cc2807831fd6b4f12e60b7d58ea55..426a0941df57f19072d1cb9f3fa3d0079db465c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -24,12 +24,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 06c6cc37ec90192486ba15010bfeb763a9ffb987..23bb050a34d9246cdf73090aa6adfca054bf8bcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -26,10 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/pooling_ops_common.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2d92056e4f522f6206e7d632f0fa1e8b793fd6e3..01b047f732f0e9fb3b45b272e7886e2f8cf4fff4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -160,17 +160,24 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = xla::Reshape(i, {1}); + // TODO(b/118437727): The absl::Span nonsense is only necessary because + // the deprecated overload creates ambiguity for the single-element span + // case. Remove it once the deprecated overload is gone. // temp = indices[i] - auto temp = xla::DynamicSlice(indices, i, {1}); + auto temp = + xla::DynamicSlice(indices, absl::Span({i}), {1}); // swap_index = swaps[i] - auto swap_index = xla::DynamicSlice(swaps, i, {1}); + auto swap_index = xla::Reshape( + xla::DynamicSlice(swaps, absl::Span({i}), {1}), {}); // swap_value = indices[swaps[i]] - auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + auto swap_value = xla::DynamicSlice( + indices, absl::Span({swap_index}), {1}); // indices[i] = indices[swaps[i]] - indices = xla::DynamicUpdateSlice(indices, swap_value, i); + indices = xla::DynamicUpdateSlice(indices, swap_value, + absl::Span({i})); // indices[swaps[i]] = temp - indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + indices = xla::DynamicUpdateSlice( + indices, temp, absl::Span({swap_index})); return std::vector{swaps, indices}; }; // for i in range(n): diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4b9e1a578be2445091228953df7e5c5e82b42c28..daefdfc58a4957d9e685d25aa90da6218f2041ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -23,13 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 9e4c57c9bf73369662274f6b783418e18ff860c2..aaf8c6075dd292e33e70683774a6c1bf374183e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index b1fa2915d59e4e5e2f2523e20e9a37898d087117..7a620d2a6518f8686ef570b33aac971d1dccb6c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -157,9 +157,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const float step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } @@ -171,9 +173,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const double step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 12830816ec16c9797f0fe4d8f3f13f5a8176161d..31d4cc131600f360c764ffa02831046c85d846e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,10 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -91,14 +92,20 @@ class SizeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int64 size = input_shape.num_elements(); - OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(input_shape.num_elements(), + std::numeric_limits::max()), errors::InvalidArgument("Size does not work for tensors > " "int32 max.")); Tensor size_constant(DT_INT32, TensorShape({})); - size_constant.scalar()() = static_cast(size); - - ctx->SetConstantOutput(0, size_constant); + const int rank = input_shape.dims(); + xla::XlaBuilder* builder = ctx->builder(); + auto size = xla::One(builder, xla::U32); + for (int64 i = 0; i < rank; ++i) { + size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); + } + size = xla::ConvertElementType(size, xla::S32); + ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 76ea5f525598f511f295eb5a30f3cf603fbf57aa..b18e3f965c427aec456ce2b188dad79485df23cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 622efac81766fc3ddaf538b58170f34fce06927a..52bed2670b4b8408e3b2f72b64bf370aea5325f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(paddings.shape()) == 2 && + paddings.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), errors::InvalidArgument("paddings should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 8e9e4daf99d3dd3b8e149e3f3e5f6c27665c0fcb..b6c96b1f582710e1cc39e6e1e0e800ef8170743d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +45,7 @@ Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); } @@ -146,9 +146,9 @@ class StackPushOp : public XlaOpKernel { xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -202,9 +202,9 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + std::vector start_indices(stack_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 10d990b3213ab882cf44a4df20a977633de3fdab..2273b592466431f59abcc43fcac4c37eecd53bff 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -288,19 +288,21 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); absl::InlinedVector dimensions_to_reverse; - absl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector slice_begin; + absl::InlinedVector slice_dims; for (int i = 0; i < begin.size(); ++i) { // TODO(phawkins): implement strides != 1 OP_REQUIRES( ctx, strides[i] == 1 || strides[i] == -1, errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); if (strides[i] > 0) { - slice_begin.push_back(begin[i]); + slice_begin.push_back(xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. - slice_begin.push_back(end[i] + 1); + slice_begin.push_back( + xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); dimensions_to_reverse.push_back(i); } @@ -311,14 +313,7 @@ class StridedSliceAssignOp : public XlaOpKernel { } rhs = xla::Reshape(rhs, slice_dims); - if (lhs_shape.dims() == 0) { - // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix - // and remove this workaround. - lhs = rhs; - } else { - lhs = xla::DynamicUpdateSlice( - lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); - } + lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 939d7e19515a1cb41e3e23e9d1fa957ae09ecab7..77a3e5c001e1c715f23ae5148f94dae2faa81acf 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -27,13 +27,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -123,7 +123,8 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span update_dims, - const xla::XlaOp& start_indices, DataType dtype) { + absl::Span start_indices, + DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); @@ -212,9 +213,9 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -263,9 +264,9 @@ class TensorArrayReadOp : public XlaOpKernel { xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + std::vector start_indices(ta_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -419,10 +420,10 @@ class TensorArrayScatterOp : public XlaOpKernel { auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = xla::Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + auto index = xla::Reshape(xla::Slice(indices, {i}, {i + 1}, {1}), {}); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 64a24703ae1460abfedb6d9298e1e164076a199a..65020012283d9c5f62e5e2fd11fc2bf1110e019a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ // XLA TensorList operators. +// Tensor lists are represented as tuple consisting of a pre-allocated list +// consisting of the tensors (and where dim 0 is the list index), along with a +// scalar telling us the current number of elements. #include #include @@ -24,13 +27,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,11 +48,27 @@ Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), tensor_list_shape); } +class TensorListLengthOp : public XlaOpKernel { + public: + explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp tl = ctx->Input(0); + xla::XlaOp index = xla::GetTupleElement(tl, 1); + ctx->SetOutput(0, index); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp); +}; + +REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -67,9 +86,10 @@ class TensorListReserveOp : public XlaOpKernel { tensor_shape.AppendShape(element_shape); xla::XlaBuilder* b = ctx->builder(); - ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, num_elements)})); } private: @@ -85,19 +105,41 @@ REGISTER_XLA_OP(Name("TensorListReserve") class EmptyTensorListOp : public XlaOpKernel { public: - explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { - ctx->CtxFailure( + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 max_num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); + OP_REQUIRES( + ctx, max_num_elements >= 0, errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Use TensorListReserve instead.")); + "size. Set the max number of elements.")); + + TensorShape tensor_shape; + tensor_shape.AddDim(max_num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); }; -REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); +REGISTER_XLA_OP(Name("EmptyTensorList") + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("max_num_elements"), + EmptyTensorListOp); class TensorListElementShapeOp : public XlaOpKernel { public: @@ -139,6 +181,136 @@ class TensorListElementShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); +class TensorListGetItemOp : public XlaOpKernel { + public: + explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); + +class TensorListStackOp : public XlaOpKernel { + public: + explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp state = ctx->Input(0); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + ctx->SetOutput(0, ta); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp); +}; + +REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); + +class TensorListFromTensorOp : public XlaOpKernel { + public: + explicit TensorListFromTensorOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &element_shape)); + + const TensorShape tensor_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, tensor_shape.dims() > 0, + errors::InvalidArgument("Input value must be at least a " + "vector but received shape: ", + tensor_shape.DebugString())); + const int num_elements = tensor_shape.dim_size(0); + + xla::XlaBuilder* b = ctx->builder(); + const xla::XlaOp tensor = ctx->Input(0); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {tensor, xla::ConstantR0(b, num_elements)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp); +}; + +REGISTER_XLA_OP( + Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), + TensorListFromTensorOp); + +class TensorListSetItemOp : public XlaOpKernel { + public: + explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp tl = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(2); + + xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp); + class TensorListPushBackOp : public XlaOpKernel { public: explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -147,25 +319,23 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp list = ctx->Input(0); + xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(list, 0); - xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp index = xla::GetTupleElement(tl, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. - ctx->SetOutput( + ctx->SetTensorListOutput( 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), index + xla::ConstantR0(b, 1)})); } @@ -197,20 +367,17 @@ class TensorListPopBackOp : public XlaOpKernel { index = index - xla::ConstantR0(b, 1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); - + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = shape.dim_sizes(); slice_shape[0] = 1LL; - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index})); ctx->SetOutput(1, xla::Reshape(read, value_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c9b324a243e4cc3ec64daa3ca0d285336a0d0154..76793d677ba45f8e863e684a149da684c8ce8787 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 8671632976023fded04c26a9780c1a67638b0916..2fc5619de737b8977e4249e4d2297a0303c339ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -24,12 +24,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 2c92a585f5679242d672d0402e617ff199b94f17..dfa09b16081e93ba843a1858e68e6ff756de20c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -291,5 +291,19 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { }; REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); +class ResourceScatterNdSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ff5255028bd012ea4d839faa59ef5930a17c5767..fd5ff10ae0a8cb39075fa6c594707dbc833f5f16 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -70,13 +70,20 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.ShapeHumanString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - arg.shape = ctx->InputShape(i); + + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp handle = ctx->Input(i); + auto shape_or_status = builder->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + arg.shape = shape_or_status.ValueOrDie(); } } return Status::OK(); @@ -206,12 +213,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape body_input_shape = body.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), + OP_REQUIRES(ctx, body_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape cond_input_shape = cond.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape), + OP_REQUIRES(ctx, cond_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) @@ -291,20 +298,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); - auto while_shape_or = builder->GetShape(while_result); - OP_REQUIRES_OK(ctx, while_shape_or.status()); - auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie()); - int max_index = body.outputs.size() + body.resource_updates.size() - 1; - OP_REQUIRES( - ctx, max_index < count, - errors::Internal("Max tuple element requested (", max_index, - ") needs to be less than tuple size (", count, ")")); - - // Sets non-variable outputs. + // Sets non-variable outputs and determine when resource variables start. + int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], xla::GetTupleElement(while_result, i)); + ++resource_index; + } else { + break; } } if (has_token_input_output_) { @@ -313,7 +315,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, ctx->num_outputs()); auto shape_or = builder->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); @@ -326,7 +328,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { - int pos = body.outputs.size() + i; + int pos = resource_index + i; OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 688056791f9750e6b22df4b2cd4643de0b780651..1cd5a79171dccd57fc1b7941cdf16417301ff7f8 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -48,7 +48,7 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { + if (num_index_dims > buffer_shape.rank()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -140,8 +140,8 @@ xla::StatusOr XlaScatter( ? indices_shape.dimensions_size() - 1 : indices_shape.dimensions_size()); - int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); - int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 updates_rank = updates_shape.rank(); + int64 buffer_rank = buffer_shape.rank(); int64 num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -156,7 +156,7 @@ xla::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = xla::ShapeUtil::Rank(updates_shape); + updates_rank = updates_shape.rank(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c0bd172d17c192435ba8ee196f9def0491c0bf5c..06eda41611861060a1f1c4d028b96405d288efdb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: return xla::ConstantR0(builder, value); break; + case xla::C128: + return xla::ConstantR0(builder, value); + break; default: LOG(FATAL) << "unhandled element type " << type; } @@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: literal = xla::LiteralUtil::CreateR0(value); break; + case xla::C128: + literal = xla::LiteralUtil::CreateR0(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 67d08290033361f16dfff42b06af9b253e84963a..749a7c3054a65d6ec9f9dc13f6f4a713ac9d3d5a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -77,7 +77,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { - TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + TF_RET_CHECK(literal.shape().IsArray() && xla::ShapeUtil::ElementsIn(literal.shape()) == host_tensor->NumElements()); xla::PrimitiveType primitive_type; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 15f4c38da29507da9e092c1d5725b5f95a81d1b9..44bccfe6474d175beda392ca17dfbcb08c0b1b11 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -49,7 +49,7 @@ using Types = std::pair, std::pair, std::pair>; -TYPED_TEST_CASE(LiteralUtilTest, Types); +TYPED_TEST_SUITE(LiteralUtilTest, Types); TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { using int_type = typename TypeParam::first_type; diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 4dce0a2102cf9c782850ccc7af4f14b59bd51e53..7140b6a1227a53290c3747892a55886a7f48513b 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -4,7 +4,11 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_wrapper_py", +) cc_library( name = "xla_ops", @@ -24,3 +28,14 @@ tf_gen_op_wrapper_py( ":xla_ops", ], ) + +tf_custom_op_library( + name = "_xla_ops.so", + srcs = [ + "xla_ops.cc", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index ab77984684db4525f4d3f42b2c9c0f093c82ec45..af641131ed76a8d6a7291c360302fa17c94af014 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -369,7 +369,11 @@ REGISTER_OP("XlaKeyValueSort") .Output("sorted_values: V") .Attr("K: realnumbertype") .Attr("V: type") - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + return Status::OK(); + }) .Doc(R"doc( Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index fef97b98c376d9df8bbfd9cb6651216895e46bf4..9abdb04d7736e8ff5225688af4759a522d3e7fc7 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -15,6 +15,7 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_py_clif_cc( name = "xla_op_registry", @@ -27,9 +28,13 @@ tf_py_clif_cc( ], ) -py_library( +tf_custom_op_py_library( name = "xla", srcs = ["xla.py"], + dso = ["//tensorflow/compiler/tf2xla/ops:_xla_ops.so"], + kernels = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + ], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/xla:xla_data_proto_py", diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index ff9f1b9ccba2c4f3307890d5aac4ddb6cfaafcd9..c20d6a5fd1f3bd7dad30cb3359d13ed4609a2250 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -77,6 +77,7 @@ CreateResourceOpInfoMap() { add("ResourceScatterMin" , kReadWrite, kVariable); add("ResourceScatterMul" , kReadWrite, kVariable); add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdSub" , kReadWrite, kVariable); add("ResourceScatterNdUpdate" , kReadWrite, kVariable); add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index ec604af13867171d558cd7324919fb9531caf460..8997b2f5c68da480e9d4cb1f7ff8776690363392 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -27,7 +27,7 @@ namespace { Status PopulateInfeedLayoutVector(const xla::Shape& shape, std::vector* layouts) { - if (xla::ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape); for (int64 i = 0; i < tuple_elements; ++i) { const xla::Shape& subshape = @@ -39,23 +39,60 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape, layouts->push_back(dim); } } else { - layouts->insert(layouts->end(), xla::ShapeUtil::Rank(shape), -1); + layouts->insert(layouts->end(), shape.rank(), -1); } return Status::OK(); } +// Populate the output layout unless the minor_to_major array contains all -1 +// value, in which case the layout is considered missing and the API returns +// false. +xla::StatusOr MakeLayout(absl::Span minor_to_major, + xla::Layout* layout) { + if (std::all_of(minor_to_major.begin(), minor_to_major.end(), + [](int64 dim) { return dim == -1; })) { + return false; + } + std::vector dim_present(minor_to_major.size(), false); + for (auto dim : minor_to_major) { + if (dim < 0 || dim >= minor_to_major.size()) { + return errors::InvalidArgument("Layout dimension out of range: dim=", dim, + " rank=", minor_to_major.size()); + } + if (dim_present[dim]) { + return errors::InvalidArgument("Repeated layout dimension: dim=", dim); + } + dim_present[dim] = true; + } + *layout = xla::LayoutUtil::MakeLayout(minor_to_major); + return true; +} + +Status AssignLayout( + absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* shape) { + xla::Layout layout; + TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout)); + if (!has_layout && layout_func) { + layout = layout_func(*shape); + } + *shape->mutable_layout() = layout; + return Status::OK(); +} + } // namespace // Convert an XLA Shape into the equivalent TensorFlow shape. Status XLAShapeToTensorShape(const xla::Shape& shape, TensorShape* tensor_shape) { - if (xla::ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return errors::InvalidArgument("XLA shape ", xla::ShapeUtil::HumanString(shape), " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { + for (int i = 0; i < shape.rank(); ++i) { tensor_shape->AddDim(shape.dimensions(i)); } return Status::OK(); @@ -84,10 +121,64 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } -xla::StatusOr> GetInfeedLayoutVector(const xla::Shape& shape) { +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { std::vector layouts; TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts)); return layouts; } +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape) { + if (input_shape.IsTuple()) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape); + std::vector shapes; + shapes.reserve(tuple_elements); + size_t position = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& shape = + xla::ShapeUtil::GetTupleElementShape(input_shape, i); + if (shape.IsTuple()) { + return errors::InvalidArgument( + "Nested tuples not supported: ", + xla::ShapeUtil::HumanString(input_shape)); + } + int64 rank = shape.rank(); + if (position + rank > minor_to_major.size()) { + return errors::InvalidArgument( + "Not enough layout attribute elements: position=", position, + " rank=", rank, " elements=", minor_to_major.size()); + } + shapes.push_back(shape); + TF_RETURN_IF_ERROR(AssignLayout( + absl::Span(minor_to_major).subspan(position, rank), + layout_func, &shapes.back())); + position += rank; + + VLOG(4) << "Shape[" << i + << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back()); + } + if (position != minor_to_major.size()) { + return errors::InvalidArgument( + "Too many elements passed in the layout attribute: position=", + position, " size=", minor_to_major.size()); + } + *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); + } else { + int64 rank = input_shape.rank(); + if (rank != minor_to_major.size()) { + return errors::InvalidArgument( + "Wrong number of layout attribute elements: rank=", rank, + " elements=", minor_to_major.size()); + } + *output_shape = input_shape; + TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape)); + + VLOG(4) << "Shape[] = " + << xla::ShapeUtil::HumanStringWithLayout(*output_shape); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index cf52bf46e7c2a237d57f4c87e7d6efbf3fa9b1c2..e775c4462c3dc15cf4b8d9e8d8e7d9a61e024cd0 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -45,12 +45,23 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape); // Given an XLA shape with layouts, builds a layout vector in the form able to -// be fed to an InfeedEnqueue/InfeedEnqueueTuple ops. +// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... // THe returned vector is a linearized sequence of the minor-to-major values of // the layouts held within the input shape. // In case the input shape is a tuple, the minor-to-major values will be in the // order of the tuple elements within the tuple shape. -xla::StatusOr> GetInfeedLayoutVector(const xla::Shape& shape); +// If a shape (or a subshape of a tuple shape) has missing layout, a rank long +// sequence of -1 values will be emittted. +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape); + +// Given the input shape and a linearized sequence of the minor-to-major values +// of the layouts, create the output shape by rewriting the input shape layouts. +// If a layout is missing (has -1 values) for a matching tuple subshape, the +// layout_func will be called, if not nullptr. +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b62f8e9115229ac35c657d374c68336f1168ff77..412f31adbb7df52b2d6933be054cc6d40947dc44 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -26,6 +26,49 @@ const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { + if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { + return errors::InvalidArgument("Node ", node->DebugString(), + " does not have attribute ", + kXlaHasHostTransferAttrName); + } + + if (node->type_string() == "_XlaRecvAtHost" || + node->type_string() == "_XlaSendFromHost") { + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else if (node->type_string() == "If") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (node->type_string() == "While") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(node->def(), "device_ordinal")) { + // Function call node containing outside compilation. + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else { + return errors::Internal("Unknown node type to set 'device_ordinal': ", + node->DebugString()); + } + return Status::OK(); +} + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index 7081b362c36c4785164b29003a5f89cd73bcf3af..75e1f253fb08ae61b0336a8783b7449c69197dd1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -38,6 +38,10 @@ extern const char kXlaTokenArgNodeName[]; // This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. extern const char kXlaHasHostTransferAttrName[]; +// Sets device ordinal attribute for nodes with attribute +// `kXlaHasHostTransferAttrName`. +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 9fac16a9700419b189bf5393c2b8bd7d76c6c1cc..cf48576ec2746fb29779633275eac4c638b91e45 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -243,7 +243,9 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + arg.shape = shape; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index d00b1376620c0c9d112c7d7426758f6d3f25e86f..732f957d7329c93ad104dacf5190948fbfd7974b 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); + case tensorflow::DT_COMPLEX128: + *type = xla::C128; + return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index c7341cf8b9e8d7a06fd304ae8766420d20f0c16e..de2e485a47c18ae8e58a06aba408dbb61a30d00a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -59,45 +59,8 @@ class XlaCompiledCpuFunction { // AOT this is backed by data compiled into the object file. // // The contents of StaticData are XLA-internal implementation details and - // should not be relied on by clients. - // - // TODO(sanjoy): Come up with a cleaner way to express the contraint we want - // here: generated XlaCompiledCpuFunction subclasses should be able to create - // instances of StaticData but only XlaCompiledCpuFunction should be able to - // read from StaticData instances. + // should not be relied on by clients (and therefore are private). class StaticData { - public: - void set_raw_function(RawFunction raw_function) { - raw_function_ = raw_function; - } - void set_buffer_infos( - const cpu_function_runtime::BufferInfo* buffer_infos) { - buffer_infos_ = buffer_infos; - } - void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } - void set_arg_index_table(const int32* arg_index_table) { - arg_index_table_ = arg_index_table; - } - void set_num_args(int64 num_args) { num_args_ = num_args; } - void set_result_index(size_t result_index) { result_index_ = result_index; } - void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } - void set_result_names(const char** result_names) { - result_names_ = result_names; - } - void set_program_shape(const xla::ProgramShapeProto* program_shape) { - program_shape_ = program_shape; - } - const xla::HloProfilePrinterData* hlo_profile_printer_data() const { - return hlo_profile_printer_data_; - } - void set_hlo_profile_printer_data( - const xla::HloProfilePrinterData* hlo_profile_printer_data) { - hlo_profile_printer_data_ = hlo_profile_printer_data; - } - void set_profile_counters_size(int64 profile_counters_size) { - profile_counters_size_ = profile_counters_size; - } - private: // The raw function to call. RawFunction raw_function_; @@ -134,7 +97,8 @@ class XlaCompiledCpuFunction { // declared so we don't have access to that information here. int64 profile_counters_size_ = 0; - // Only XlaCompiledCpuFunction is allowed to read the above fields. + // Only XlaCompiledCpuFunction is allowed to read and write the above + // fields. friend class XlaCompiledCpuFunction; }; @@ -148,7 +112,7 @@ class XlaCompiledCpuFunction { RESULTS_PROFILES_AND_TEMPS_ONLY, }; - XlaCompiledCpuFunction( + explicit XlaCompiledCpuFunction( const StaticData& static_data, AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); @@ -280,6 +244,76 @@ class XlaCompiledCpuFunction { return *hlo_profile_printer_data_; } + protected: + // --------------------------------------------------------------------------- + // Accessors for reading from and writing to instances of `StaticData`. + // + // Classes generated by tfcompile can call these because the generated classes + // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can + // call these because it is explicitly added as a friend. + + static void set_static_data_raw_function(StaticData* static_data, + RawFunction raw_function) { + static_data->raw_function_ = raw_function; + } + + static void set_static_data_buffer_infos( + StaticData* static_data, + const cpu_function_runtime::BufferInfo* buffer_infos) { + static_data->buffer_infos_ = buffer_infos; + } + + static void set_static_data_num_buffers(StaticData* static_data, + size_t num_buffers) { + static_data->num_buffers_ = num_buffers; + } + + static void set_static_data_arg_index_table(StaticData* static_data, + const int32* arg_index_table) { + static_data->arg_index_table_ = arg_index_table; + } + + static void set_static_data_num_args(StaticData* static_data, + int64 num_args) { + static_data->num_args_ = num_args; + } + + static void set_static_data_result_index(StaticData* static_data, + size_t result_index) { + static_data->result_index_ = result_index; + } + + static void set_static_data_arg_names(StaticData* static_data, + const char** arg_names) { + static_data->arg_names_ = arg_names; + } + + static void set_static_data_result_names(StaticData* static_data, + const char** result_names) { + static_data->result_names_ = result_names; + } + + static void set_static_data_program_shape( + StaticData* static_data, const xla::ProgramShapeProto* program_shape) { + static_data->program_shape_ = program_shape; + } + + static void set_static_data_hlo_profile_printer_data( + StaticData* static_data, + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; + } + + static const xla::HloProfilePrinterData* + get_static_data_hlo_profile_printer_data(StaticData* static_data) { + return static_data->hlo_profile_printer_data_; + } + + static void set_static_data_profile_counters_size( + StaticData* static_data, int64 profile_counters_size) { + static_data->profile_counters_size_ = profile_counters_size; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -313,6 +347,10 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the + // `set_static_data_*` static methods above. + friend class XlaJitCompiledCpuFunction; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ee461a3c07d4db514c7697e005a9371be4b54dd0..1f9cfcdd246f36bd7e0325bca34c7480d4ce2843 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -192,6 +193,8 @@ Status BuildComputation( output.shape = output.constant_value.shape(); break; + case XlaExpression::Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case XlaExpression::Kind::kXlaOp: { output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); @@ -333,8 +336,21 @@ bool XlaCompiler::Argument::operator==( other.tensor_array_gradients)) { return false; } - if (shape != other.shape) { - return false; + if (absl::holds_alternative(shape)) { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get(shape), + absl::get(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (absl::get(shape) != absl::get(other.shape)) { + return false; + } } if (constant_value.shape() != other.constant_value.shape()) { return false; @@ -348,7 +364,7 @@ string XlaCompiler::Argument::HumanString() const { common = absl::StrCat(" name=", name); } absl::StrAppend(&common, " type=", DataTypeString(type), - " shape=", shape.DebugString()); + " shape=", ShapeHumanString()); switch (kind) { case kInvalid: return "invalid"; @@ -375,6 +391,23 @@ string XlaCompiler::Argument::HumanString() const { } } +std::vector XlaCompiler::Argument::DimensionSizes() const { + if (absl::holds_alternative(shape)) { + return xla::InlinedVectorToVector( + absl::get(shape).dim_sizes()); + } else { + return absl::get(shape).dimensions(); + } +} + +string XlaCompiler::Argument::ShapeHumanString() const { + if (absl::holds_alternative(shape)) { + return absl::get(shape).DebugString(); + } else { + return absl::get(shape).DebugString(); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -462,8 +495,34 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. + auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + /*device=*/nullptr, &graph, graph_optimizer_options); return graph; } @@ -548,11 +607,22 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { if (is_entry_computation) { - TF_ASSIGN_OR_RETURN( - *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); + TensorShape shape; + if (absl::holds_alternative(arg.shape)) { + shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &shape)); + } + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn(shape, arg.type)); } else { - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); + if (absl::holds_alternative(arg.shape)) { + *xla_shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + arg.type, absl::get(arg.shape), xla_shape)); + } } return Status::OK(); } @@ -561,8 +631,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( - arg.shape, arg.type)); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn( + absl::get(arg.shape), arg.type)); return Status::OK(); } @@ -571,9 +643,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); if (!arg.tensor_array_gradients.empty()) { @@ -588,9 +661,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); @@ -620,14 +694,15 @@ Status XlaCompiler::BuildArguments( bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, std::vector* input_shapes, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); // Argument numbers of arguments and resources that are to be passed to the - // XLA computation as runtime parameters. - input_mapping->clear(); - input_mapping->reserve(args.size()); + // XLA computation as runtime parameters. `input_to_args[a] = b` means that + // the a'th XLA input corresponds to the b'th original arg indexes. + input_to_args->clear(); + input_to_args->reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -637,24 +712,25 @@ Status XlaCompiler::BuildArguments( switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource = context->AddResource(absl::make_unique( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::XlaOp(), + arg.resource_kind, i, arg.name, arg.type, + absl::get(arg.shape), xla::XlaOp(), /*max_array_size=*/arg.max_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { - input_mapping->push_back(i); + input_to_args->push_back(i); } break; } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { - input_mapping->push_back(i); + input_to_args->push_back(i); break; } case XlaCompiler::Argument::kConstant: @@ -666,15 +742,23 @@ Status XlaCompiler::BuildArguments( } } - if (input_mapping->empty()) { + if (input_to_args->empty()) { return Status::OK(); } - std::vector arg_shapes(input_mapping->size()); - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds + // to the d'th XLA input. Note that the value -1 corresponds to constants, or + // other args that don't correspond to an input. + std::vector arg_to_inputs(args.size(), -1); + for (int i = 0; i < input_to_args->size(); i++) { + arg_to_inputs[input_to_args->at(i)] = i; + } + + std::vector arg_shapes(input_to_args->size()); + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { // Computes the shapes of non-constant arguments. TF_RETURN_IF_ERROR(XLAShapeForArgument( - args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); + args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -691,13 +775,13 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_to_args->size()); if (use_tuple_arg) { xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : *input_mapping) { + for (int64 parameter : *input_to_args) { auto it = arg_cores.find(parameter); const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = @@ -709,7 +793,19 @@ Status XlaCompiler::BuildArguments( } else { tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, + /*target_param_num=*/0, /*target_param_index=*/{i}, + dim_and_arg_num.first)); + } + } + + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -718,7 +814,7 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -727,6 +823,17 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], absl::StrCat("arg", i)); } + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/dynamic_size_param_index, {}, + /*target_param_num=*/i, /*target_param_index=*/{}, + dim_and_arg_num.first)); + } + } } builder->ClearOpMetadata(); @@ -734,12 +841,12 @@ Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << input_mapping->at(i); - XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; + << " name: " << arg.name << " TF arg " << input_to_args->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -756,7 +863,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 0d801b73a8c2651305328384377751254ecaa41d..ad3144b41bdf3fc8b75ab5230e8e128df2962884 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -124,7 +125,8 @@ class XlaCompiler { DataType type = DT_INVALID; // The shape of the argument. For: - // * a parameter: the shape of the parameter. + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. // * a constant: ignored; the shape given by constant_value is used // instead. // * an uninitialized resource: ignored. We don't yet know the shape of an @@ -133,7 +135,7 @@ class XlaCompiler { // * an initialized TensorArray or Stack resource: the shape of an entry in // the TensorArray/Stack. Note this is the size of a single entry, not the // XLA data structure that represents the complete stack/array. - TensorShape shape; + absl::variant shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -157,10 +159,20 @@ class XlaCompiler { // as `tensor_array_gradients`. std::set tensor_array_gradients; + // dynamic dims to arg number map. Empty if no dynamic shapes. + std::map dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + bool operator==(const Argument& other) const; // Returns a human-readable summary of the argument. string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -420,7 +432,7 @@ class XlaCompiler { XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index fe2a5f5b0c9ea6b5f2bb71df836fdcabf9a0cf23..492010f7317d32a8a620147cd2cd9356d4f13fde 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -82,7 +82,7 @@ namespace { // compiled kernels. class DummyResourceForTest : public ResourceBase { public: - string DebugString() override { return "dummy"; } + string DebugString() const override { return "dummy"; } void Increment() { ++value_; } int Get() { return value_; } @@ -1362,7 +1362,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 1); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); } { @@ -1380,11 +1380,11 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[1])); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_input_shapes[1].IsToken()); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken( - xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1))); + EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1) + .IsToken()); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index a69af70503376b6c0905deb8980abdc3254a6e47..6139bf3cea0790c2697130a993e92be96c81848b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -61,7 +61,7 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "XLA JIT context"; } +string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { if (retvals_.size() <= index) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 0767d1faac14cedb8666f6cc37175eb7b55f6158..eb4ad3fe6a14b42a4df2c73c71cb6df1331fd796 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -47,7 +47,7 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. - string DebugString() override; + string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index ca0309166b7c73d1a5a818091e2a30fa112a4de4..3d228c92adcbe3d093a4fe70d157e57ab3e80c80 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -46,6 +46,14 @@ XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { return e; } +XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) { + XlaExpression e; + e.kind_ = Kind::kTensorList; + e.dtype_ = DT_VARIANT; + e.handle_ = tensor_list; + return e; +} + XlaExpression XlaExpression::Resource(XlaResource* resource) { XlaExpression e; e.kind_ = Kind::kResource; @@ -64,6 +72,8 @@ string XlaExpression::HumanString() const { return "xla_op"; case Kind::kResource: return "resource"; + case Kind::kTensorList: + return "tensor_list"; } } @@ -76,6 +86,8 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { HostTensorToBorrowingLiteral(constant_value_, &literal)); return xla::ConstantLiteral(builder, literal); } + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kXlaOp: if (builder != handle_.builder()) { return errors::InvalidArgument( @@ -96,7 +108,10 @@ xla::StatusOr> XlaExpression::ResolveConstant( return {constant_value()}; case Kind::kXlaOp: break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kResource: + TF_FALLTHROUGH_INTENDED; case Kind::kInvalid: return errors::InvalidArgument( "ResolveConstant called on XlaExpression: ", HumanString()); @@ -134,6 +149,8 @@ xla::StatusOr XlaExpression::GetShape() const { TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); return shape; } + case Kind::kTensorList: + return TensorShape({}); case Kind::kResource: return TensorShape({}); case Kind::kInvalid: diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index bed6761d362a98d344003c1edea342e68c31ef07..ac0232d8924cf2c9e35ad3f0772a3a2adc18af87 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -32,11 +32,16 @@ namespace tensorflow { // * a constant tensor. // * an xla::XlaOp, representing a symbolic XLA value. // * a resource, e.g., a variable, represented as an XlaResource pointer. +// * a tensor list, represented by a tuple of tensors and the list length. // // Constant tensors are mostly an optimization to avoid passing large constants // to XLA, but are also sometimes used to represent tensors that have no XLA // representation, for example, DT_STRING tensors. A canonical use case might be // an error message string. +// +// Tensor lists are very similar to xla::XlaOp, however they require some +// specific logic around shape management since the tuples are not supported by +// TensorFlow. class XlaExpression { public: enum class Kind { @@ -44,6 +49,7 @@ class XlaExpression { kConstant, kXlaOp, kResource, + kTensorList, }; XlaExpression(); @@ -62,6 +68,9 @@ class XlaExpression { // be derived from the XLA type. static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + // Builds a tensor list expression. + static XlaExpression TensorList(xla::XlaOp tensor_list); + // Builds a resource expression. static XlaExpression Resource(XlaResource* resource); @@ -100,7 +109,8 @@ class XlaExpression { DataType dtype_ = DT_INVALID; - // The XLA handle of the expression's computation, if kind_ == kXlaOp. + // The XLA handle of the expression's computation, if kind_ == kXlaOp or + // a tuple expression if kind_ == kTensorList. xla::XlaOp handle_; // The value of the constant, if kind_ == kConstant. diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index c2c0751211180c3715a19d6c78e34659fd18914e..04a5d934064a9083a41cc210b48df65bbc862fff 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -34,63 +34,6 @@ limitations under the License. namespace tensorflow { -namespace { - -xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, - bool is_min) { - xla::XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); - xla::XlaOp init_value; - xla::XlaComputation reducer; - if (is_min) { - init_value = xla::MaxValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMinComputation(input_shape.element_type(), builder); - } else { - init_value = xla::MinValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMaxComputation(input_shape.element_type(), builder); - } - - xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = xla::ConvertElementType( - xla::Eq(input, input_max, broadcast_dims), output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - xla::XlaOp shift_amount = - xla::ConstantR0WithType(builder, output_type, bits_in_type); - xla::XlaOp full_mask = xla::ShiftRightArithmetic( - xla::ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); - xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); - xla::XlaOp product = - xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - return xla::Reduce(product, xla::MinValue(builder, output_type), - xla::CreateScalarMaxComputation(output_type, builder), - /*dimensions_to_reduce=*/{axis}); - }); -} - -} // namespace - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); @@ -120,7 +63,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, /* static */ Status XlaHelpers::ReshapeLiteral( const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { - if (xla::ShapeUtil::IsTuple(input.shape())) { + if (input.shape().IsTuple()) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); } xla::Shape shape = @@ -148,16 +91,6 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/false); -} - -xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/true); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 4858dfee55a393d04cd2af83916eeb40820ee368..490923526bd3acd4b167ccb3faff1d6c9e631131 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -53,16 +53,6 @@ class XlaHelpers { absl::Span shape, xla::Literal* output); - // Returns the argmax of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - - // Returns the argmin of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index fabbcd04fed96ad814d04c2df9394f43bfe0cf99..884dc45cb11b18ae557c3da3f4192b3805cb7980 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -135,24 +135,34 @@ XlaJitCompiledCpuFunction::Compile( jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = absl::make_unique(program_shape->ToProto()); - jit->static_data_.set_raw_function(raw_function); - jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); - jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); - jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); - jit->static_data_.set_num_args(jit->arg_index_table_.size()); - jit->static_data_.set_result_index(result_index); + XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, + raw_function); + XlaCompiledCpuFunction::set_static_data_buffer_infos( + &jit->static_data_, jit->buffer_infos_.data()); + XlaCompiledCpuFunction::set_static_data_num_buffers( + &jit->static_data_, jit->buffer_infos_.size()); + XlaCompiledCpuFunction::set_static_data_arg_index_table( + &jit->static_data_, jit->arg_index_table_.data()); + XlaCompiledCpuFunction::set_static_data_num_args( + &jit->static_data_, jit->arg_index_table_.size()); + XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, + result_index); // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.set_arg_names(jit->arg_names_.data()); - jit->static_data_.set_result_names(jit->result_names_.data()); - jit->static_data_.set_program_shape(jit->program_shape_.get()); + XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_, + jit->arg_names_.data()); + XlaCompiledCpuFunction::set_static_data_result_names( + &jit->static_data_, jit->result_names_.data()); + XlaCompiledCpuFunction::set_static_data_program_shape( + &jit->static_data_, jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.set_hlo_profile_printer_data( - &cpu_executable->hlo_profile_printer_data()); - jit->static_data_.set_profile_counters_size( + XlaCompiledCpuFunction::set_static_data_hlo_profile_printer_data( + &jit->static_data_, &cpu_executable->hlo_profile_printer_data()); + XlaCompiledCpuFunction::set_static_data_profile_counters_size( + &jit->static_data_, cpu_executable->hlo_profile_printer_data().profile_counters_size()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 58808c76de6330a6b28e21dbdead03dea25847f6..78bc2c94425e00c2b26058daf609d71f1853664e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -93,7 +93,7 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { } DataType XlaOpKernelContext::input_type(int index) const { - return context_->input(index).dtype(); + return context_->input_dtype(index); } DataType XlaOpKernelContext::InputType(absl::string_view name) { @@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( // Converts an int32 or int64 scalar literal to an int64. static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { @@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, // Converts an float32 or float64 scalar literal to a float64. static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, double* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::F32) { @@ -228,8 +228,9 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 1) { - return errors::InvalidArgument("value is not 1D"); + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); } int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { @@ -353,8 +354,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); if (!variable->initialized()) { - return errors::InvalidArgument("Read of uninitialized variable ", - variable->name()); + return errors::FailedPrecondition("Read of uninitialized variable ", + variable->name()); } if (variable->type() != type) { return errors::InvalidArgument( @@ -456,6 +457,11 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { SetOutputExpression(index, XlaExpression::Constant(constant)); } +void XlaOpKernelContext::SetTensorListOutput(int index, + const xla::XlaOp& handle) { + SetOutputExpression(index, XlaExpression::TensorList(handle)); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { SetOutputExpression(index, XlaExpression::Resource(resource)); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 1858844bc05a6e12abbf07af83cad816590ddd03..e44415f60bff82fb92d0cf4ec81935564a2f083a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -168,6 +168,9 @@ class XlaOpKernelContext { // Returns an XlaExpression describing the value of 'index'. void SetOutputExpression(int index, const XlaExpression& expression); + // Sets output `index` to the Tensor List `handle`. + void SetTensorListOutput(int index, const xla::XlaOp& handle); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 0bdd4a1085445420a5147756daac4a54f4725f11..ce3b6b298c6dc5a08e7b794bbab3a28575967d28 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { +constexpr std::array kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { +constexpr std::array kCpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL}}; constexpr std::array kGpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 722d1376687efa1c04158e3fd9ce539aac9d0122..636e5ef721f58c009566c10a653d09a7667619c0 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -152,7 +152,7 @@ cc_library( ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/stream_executor", + "//tensorflow/stream_executor/lib", ], ) @@ -717,6 +717,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -741,6 +742,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -824,6 +826,7 @@ cc_library( "debug_options_parsers.h", ], hdrs = ["debug_options_flags.h"], + visibility = [":friends"], deps = [ ":parse_flags_from_env", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 58cc1575858201b4508d7340cb47e59c4f4c5783..529e7f77cec43f3158fcb59a53efa9a085d7422a 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -272,6 +272,15 @@ class Array { std::iota(&values_[0], &values_[0] + num_elements(), value); } + // Fills the array with a repeating sequence: + // [value, value + 1, ..., value + length - 1, value, ... ] + void FillRepeatedIota(const T& value, int64 length) { + for (int64 i = 0; i < num_elements(); i += length) { + std::iota(&values_[i], &values_[std::min(i + length, num_elements())], + value); + } + } + // Fills the array with the sequence i*multiplier for i=0,1,... void FillWithMultiples(const T& multiplier) { for (int64 i = 0; i < num_elements(); ++i) { @@ -280,11 +289,11 @@ class Array { } // Fills the array with random normal variables with the specified mean. - void FillRandom(const T& value, const double mean = 0.0, + void FillRandom(const T& stddev, const double mean = 0.0, const int seed = 12345) { std::mt19937 g(seed); std::normal_distribution distribution(mean, - static_cast(value)); + static_cast(stddev)); for (int64 i = 0; i < num_elements(); ++i) { values_[i] = static_cast(distribution(g)); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 27c075e8f13f6777af4e837501d97a33034313f5..f5d56e8a9e1f3a05e1039f7cc90194407200f1ab 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -246,6 +246,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 43127cae1e5d81521003a28288e27d291e33c9b9..4f020bcec2756a328755d86ab04154d54f532465 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -278,53 +278,51 @@ StatusOr> Client::Execute( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { - if (execution_options != nullptr && - execution_options->device_handles_size() > 1) { - std::vector computation_instances = { - XlaComputationInstance{ - computation, - std::vector(arguments.begin(), arguments.end()), - *execution_options, execution_profile}}; - TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); - // The result selection is a bit hacky, but better than assuming it is - // device 0. - // - // TODO(b/118493728): Allow Execute to return one result per computation. - for (int64 i = 0; i < results.size(); i++) { - TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); - if (!ShapeUtil::IsEmptyTuple(shape)) { - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return std::move(results[i]); - } + // Create an ExecutionOptions if necessary, or set its DeviceHandles. + absl::optional options_storage; + if (!execution_options || execution_options->device_handles().empty()) { + if (execution_options) { + options_storage.emplace(*execution_options); + } else { + options_storage.emplace(CreateDefaultExecutionOptions()); } - TF_RET_CHECK(!results.empty()); - VLOG(1) << "Defaulting to device 0 result"; - return std::move(results[0]); - } - - // The argument shapes affect how the computation is compiled. - std::vector arg_shapes(arguments.size()); - for (int i = 0; i < arguments.size(); i++) { - TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); - } - - TF_ASSIGN_OR_RETURN(auto handle, - Compile(computation, arg_shapes, execution_options)); - - TF_ASSIGN_OR_RETURN(auto result, - Execute(handle, arguments, execution_profile)); - - if (execution_profile != nullptr) { - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, *execution_profile)); - VLOG(1) << execution_stats; + execution_options = &*options_storage; + + TF_ASSIGN_OR_RETURN(auto device_handles, + GetDeviceHandles(/*device_count=*/1)); + TF_RET_CHECK(!device_handles.empty()); + *options_storage->add_device_handles() = std::move(device_handles[0]); + } + + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + + // Instead of invoking Compile() and Execute(), invoke + // Service::ExecuteParallel() to execute our one computation. Compile() + // caches the executable forever, which isn't what we want. + VLOG(1) << "Making ExecuteParallel request: " + << execution_options->DebugString(); + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + VLOG(1) << "ExecuteParallel request done."; + + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. + for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); } } - - return std::move(result); + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); } StatusOr>> Client::ExecuteParallel( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d0ac4703c632e0e01d3c8911594b46fedf28930d..eff8713ac340e82ee7633f1f078334ba73b67b2f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -52,6 +52,12 @@ class Client { // need to live beyond this call.) // * If execution_options.device_handles should be empty. If you need // non-empty device handles, call 'Execute' instead. + // + // TODO(b/122731460): This call caches the resulting Executable in the Service + // *forever*. If you're only going to run the computation once, you may want + // to call the Execute(const XlaComputation&) overload. If you're going to + // run the computation more than once but you want control over when the + // Executable is unloaded, use the LocalClient API. StatusOr Compile( const XlaComputation& computation, absl::Span argument_shapes, @@ -76,6 +82,10 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. + // + // TODO(b/122731460): The given computation is compiled and then thrown away + // immediately after it's run. If you want control over how long the + // resulting Executable lives, use the LocalClient API. StatusOr> Execute( const XlaComputation& computation, absl::Span arguments, diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 1f594e551af381d7537e947892cbf7e0b5b3b861..ec0e08975926f36c36c854f83a40b374b12a09a4 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -58,6 +58,12 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? &result_layout_ : nullptr; } +ExecutableBuildOptions& ExecutableBuildOptions::set_num_replicas( + int num_replicas) { + num_replicas_ = num_replicas; + return *this; +} + string ExecutableBuildOptions::ToString() const { string result_layout = "nullopt"; if (result_layout_set_) { @@ -65,8 +71,9 @@ string ExecutableBuildOptions::ToString() const { } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " - "generate_hlo_graph=%s}", - device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); + "generate_hlo_graph=%s, num_replicas=%d}", + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph(), + num_replicas_); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index a58090253bfac7779e4b61bc7231a0f0d945cc00..1d85fb34304b95d1fccdb0b0d6a7a65e739fae18 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -67,12 +67,18 @@ class ExecutableBuildOptions { // debugging. string ToString() const; + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int num_replicas() const { return num_replicas_; } + ExecutableBuildOptions& set_num_replicas(int num_replicas); + private: int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; + int num_replicas_ = 1; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 826b13fe3733b3334d2213eeb1d10cdd53d2f134..26c5e8eb73f0908cdc2d7df65936fadeda627423 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -34,6 +34,21 @@ cc_library( ], ) +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "cholesky", srcs = ["cholesky.cc"], @@ -76,6 +91,39 @@ xla_test( ], ) +cc_library( + name = "comparators", + srcs = ["comparators.cc"], + hdrs = ["comparators.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "comparators_test", + srcs = ["comparators_test.cc"], + deps = [ + ":comparators", + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -93,7 +141,6 @@ cc_library( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":constants", "//tensorflow/compiler/xla:test", @@ -147,7 +194,22 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +xla_test( + name = "math_exhaustive_test", + srcs = ["math_exhaustive_test.cc"], + shard_count = 16, deps = [ ":math", "//tensorflow/compiler/xla:literal_util", @@ -168,12 +230,16 @@ cc_library( ":arithmetic", ":constants", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -181,16 +247,19 @@ cc_library( xla_test( name = "matrix_test", srcs = ["matrix_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":matrix", ":slicing", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", ], ) @@ -229,7 +298,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", "@com_google_absl//absl/base", ], ) @@ -281,12 +349,7 @@ cc_library( srcs = ["slicing.cc"], hdrs = ["slicing.h"], deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", ], @@ -295,13 +358,11 @@ cc_library( xla_test( name = "slicing_test", srcs = ["slicing_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":slicing", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -324,12 +385,10 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -352,7 +411,10 @@ cc_library( xla_test( name = "quantize_test", srcs = ["quantize_test.cc"], - tags = ["enable_for_xla_interpreter"], + # TODO(b/122119490): re-enable TAP after fixing. + tags = [ + "notap", + ], deps = [ ":quantize", "//tensorflow/compiler/xla:test", @@ -410,24 +472,23 @@ cc_library( xla_test( name = "triangular_solve_test", srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 + tags = [ + "enable_for_xla_interpreter", + "noasan", # sometimes times out, http://b/78650012 + ], deps = [ + ":math", ":matrix", ":triangular_solve", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index e86c10f030f3990d67e5a6638100640f73c82307..3b875135af29f142463ffd783bfeaadc61ada1af 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -117,10 +117,70 @@ XlaOp Any(XlaOp predicates) { XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::vector all_dimensions(predicates_shape.rank()); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return Reduce(predicates, f, logical_or, all_dimensions); }); } +namespace { + +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + XlaOp init_value; + XlaComputation reducer; + if (is_min) { + init_value = MaxValue(builder, input_shape.element_type()); + reducer = CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = MinValue(builder, input_shape.element_type()); + reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + XlaOp input_max = Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + XlaOp partial_mask = + ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); + XlaOp full_mask = ShiftRightArithmetic( + ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); + XlaOp iota = Iota(builder, output_type, axis_size); + XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + return Reduce(product, MinValue(builder, output_type), + CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); +} + +} // namespace + +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); +} + +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 632e8cc8bc64fad236a0226c6e93079aadde7050..d4a7812c441c351b121e5d72faf9642b06728b18 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Returns the argmin of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a13839f9db89b9c07f2465867a503ef2193f8160 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ArithmeticTest = ClientLibraryTestBase; + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/0); + + std::vector expected = {0, 2, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/1); + + std::vector expected = {0, 1, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/0); + + std::vector expected = {2, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/1); + + std::vector expected = {1, 0, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc index fd98049968491d80b9717a2de1f34997bd9d18c1..414bd1494cd32f32a5c37e84119de930678a776b 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -54,7 +55,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int n_dims = ShapeUtil::Rank(a_shape); + const int n_dims = a_shape.rank(); const int64 n = ShapeUtil::GetDimension(a_shape, -1); auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( @@ -67,29 +68,26 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { auto body_fn = [&](XlaOp i, absl::Span loop_vars, XlaBuilder* body_builder) -> StatusOr> { - Shape col_shape; - Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = ConstantR1(body_builder, mask_vector); + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + row_shape_dims.push_back(1); + row_shape_dims.push_back(n); + auto mask_zeros_row = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims)); + + col_shape_dims.push_back(n); + col_shape_dims.push_back(1); + auto mask_zeros_col = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims)); + auto mask_range_row = - Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims), + /*iota_dimension=*/n_dims - 1); auto mask_range_col = - Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims), + /*iota_dimension=*/n_dims - 2); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; @@ -144,7 +142,7 @@ XlaOp Cholesky(XlaOp a, int64 block_size, XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int ndims = ShapeUtil::Rank(a_shape); + const int ndims = a_shape.rank(); if (ndims < 2) { return InvalidArgument( "Argument to Cholesky must have rank >= 2; shape was %s", @@ -158,6 +156,12 @@ XlaOp Cholesky(XlaOp a, int64 block_size, ShapeUtil::HumanString(a_shape)); } + if (primitive_util::IsComplexType(a_shape.element_type())) { + return Unimplemented( + "Complex types are not implemented in Cholesky; got shape %s", + ShapeUtil::HumanString(a_shape)); + } + if (block_size < 1) { return InvalidArgument( "block_size argument to Cholesky must be >= 1; got %d", block_size); diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc index ba9580a3d32225625acc1447344b7d2c16c5d8a5..095dd4fbf8b7c90047c4428b50c626c16e9c1e94 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky_test.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -157,10 +157,10 @@ XLA_TEST_P(RandomCholeskyTest, Random) { xla::ErrorSpec(1e-4, 1e-4)); } -INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest, - ::testing::Values(CholeskyTestCase{1, 1}, - CholeskyTestCase{1, 2}, - CholeskyTestCase{10, 5}, - CholeskyTestCase{2, 20})); +INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); } // namespace diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc new file mode 100644 index 0000000000000000000000000000000000000000..c620c9841a5146618e3a142adeb3fe2da525950a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&, + absl::Span); + +XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + // Switch from a floating point value to a integer value in such a way that + // when using the integer value to compare, we get the same result for normal + // values, and -Nan is treated as the smallest value, and Nan is treated as + // the largest value. + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? numeric_limits::max() - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + // Note that in order to avoid -x to overflow, we calculate + // numeric_limits::max() - x as unsigned, and then convert back to + // signed. + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type)); + return Select(is_negative, flipped_value, signed_value); +} + +XlaComputation CreateScalarComparisonComputation( + const string& name, const std::vector& operand_types, + XlaBuilder* builder, XlaOpGenerator generator) { + // Create a default computation where we compare only the first two + // parameters of type 'operand_types[0]'. + auto b = builder->CreateSubBuilder(name); + if (operand_types.empty()) { + b->ReportError(InvalidArgument("operand_types should not be empty")); + return b->BuildAndNoteError(); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + // For each type in 'operand_types' we create two parameters of this type. The + // idea is that this computation can be used by n-ary Sort, and potentially + // should support comparing also the other operands of sort. In this default + // computation, however, we will not actually use any parameters except the + // first two. + for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. + if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = ConvertElementType(first_lhs_param, F32); + first_rhs_param = ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + generator(first_lhs_param, first_rhs_param, {}); + return b->BuildAndNoteError(); +} +} // namespace + +// Creates a scalar less-than computation and returns it. +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-less-than", operand_types, + builder, Lt); +} + +// Creates a scalar greater-than computation and returns it. +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-greater-than", + operand_types, builder, Gt); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcfc227dd495537f59bf0a9090bad8ade15da62 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar less-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar greater-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/tensorflow/compiler/xla/client/lib/comparators_test.cc b/tensorflow/compiler/xla/client/lib/comparators_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..598956803b34702b1e095a342648d348fa350b29 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ComparatorsTest : public ClientLibraryTestBase { + public: + ComparatorsTest() : builder_(TestName()) {} + XlaBuilder* builder() { return &builder_; } + + private: + XlaBuilder builder_; +}; + +template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> +void BuildComparatorAndComparisons(ComparatorsTest* test, + bool compare_less_than, + absl::InlinedVector* expected) { + auto compare = compare_less_than + ? CreateScalarLtComputation({type}, test->builder()) + : CreateScalarGtComputation({type}, test->builder()); + + auto negative_nan = ConstantR0( + test->builder(), -T(std::numeric_limits::quiet_NaN())); + auto positive_nan = ConstantR0(test->builder(), + T(std::numeric_limits::quiet_NaN())); + auto negative_zero = ConstantR0(test->builder(), T(-0.)); + auto positive_zero = ConstantR0(test->builder(), T(0.)); + auto negative_infinity = MinValue(test->builder(), type); + auto positive_infinity = MaxValue(test->builder(), type); + + // List the values in the expected sorting order from smallest to largest. + std::vector all_constants{negative_nan, negative_infinity, + negative_zero, positive_zero, + positive_infinity, positive_nan}; + + // Do pairwise comparisons. + std::vector all_comparisons; + for (const XlaOp& lhs_constant : all_constants) { + for (const XlaOp& rhs_constant : all_constants) { + all_comparisons.push_back(Broadcast( + Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1})); + } + } + + // Concantenate the comparison results. + ConcatInDim(test->builder(), all_comparisons, 0); + + // If we use less-than comparisons, we expect the comparison to result in true + // if the lhs value to be compared appears earlier in 'all_constants' than the + // rhs value. Likewise, if we use greater-than comparisons, we expect the + // comparison to return true if the rhs value appears earlier in + // 'all_constants' than the lhs value. + expected->clear(); + for (int i = 0; i < all_constants.size(); ++i) { + for (int j = 0; j < all_constants.size(); ++j) { + expected->push_back(compare_less_than ? i < j : i > j); + } + } +} + +XLA_TEST_F(ComparatorsTest, CompareLtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF32) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF32) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 81624614c1e3599dfe116eb61d9e2edcd5230684..4e5310a380e8bda15348dae2cbb0ea9e2c381bcb 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { return ConstantR0(builder, static_cast(value)); case C64: return ConstantR0(builder, static_cast(value)); + case C128: + return ConstantR0(builder, static_cast(value)); case U8: return ConstantR0(builder, static_cast(value)); case U32: @@ -88,6 +90,27 @@ XlaOp ScalarLike(XlaOp prototype, T value) { }); } +// Returns an array or scalar containing copies of `value` cast to the same +// run-type type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. +template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + // Returns a scalar with value '0' of 'type'. XlaOp Zero(XlaBuilder* builder, PrimitiveType type); diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 36fdda39b4124b9100c6054160f9c17bdf787d6f..253b3440e200d04e76fb64b90c1707d8a21869e8 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -78,26 +79,79 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { } // Compute an approximation of the error function complement (1 - erf(x)). +// +// TODO(jlebar): This is not particularly efficient. The implementation in +// Cephes that this follows was written for double precision, but our +// coefficients are specified only to single-precision! Cephes has a different, +// simpler implementation for single-precision. +// +// Furthermore, we could simplify this further for f16 -- for example, because +// exp(-4.2 * 4.2) = 0 (f16), the computations in service of the x < 8.0 branch +// below are unnecessary. +// +// See also these alternate implementations of erf and erfc: +// +// https://stackoverflow.com/questions/35148198 +// https://stackoverflow.com/questions/35966695 +// XlaOp Erfc(XlaOp x) { - XlaOp abs_x = Abs(x); - XlaOp z = Exp(-x * x); - - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); - - XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs. (We could extend erfc to accept complex + // types, but it doesn't seem necessary at this point.) + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); + if (!ShapeUtil::ElementIsFloating(shape)) { + return InvalidArgument( + "erfc only accepts real floating-point arrays or scalars, but got %s", + shape.ToString()); + } + XlaOp abs_x = Abs(x); + XlaOp z = Exp(-x * x); + + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); + + XlaOp abs_x_small = Lt(abs_x, ScalarLike(x, 8.0)); + XlaOp y = Select(abs_x_small, z * pp / pq, z * pr / ps); + XlaOp result_no_underflow = + Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); + + // Check for edge cases, namely, exp(-x^2) is exactly 0, or the appropriate + // denominator (ps or pq) is inf. (The check for exp(-x^2) == 0 is + // necessary only for x == +/- inf, where this check lets us avoid + // multiplying 0 by inf and getting nan.) + auto is_pos_inf = [](XlaOp op) { + return And(Not(IsFinite(op)), Gt(op, ScalarLike(op, 0))); + }; + XlaOp underflow = + Or(Eq(z, ScalarLike(z, 0)), Or(And(is_pos_inf(pq), abs_x_small), + And(is_pos_inf(ps), Not(abs_x_small)))); + XlaOp result_underflow = + Select(Lt(x, ScalarLike(x, 0)), FullLike(x, 2), FullLike(x, 0)); - return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); + return Select(underflow, result_underflow, result_no_underflow); + }); } // Compute a polynomial approximation of the error function. XlaOp Erf(XlaOp x) { - XlaOp z = x * x; - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); - return x * pt / pu; + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs. (We could extend erf to accept complex + // types, but it doesn't seem necessary at this point.) + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); + if (!ShapeUtil::ElementIsFloating(shape)) { + return InvalidArgument( + "erf only accepts real floating-point arrays or scalars, but got %s", + shape.ToString()); + } + XlaOp z = x * x; + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); + return x * pt / pu; + }); } // Approximation for the inverse error function from @@ -113,37 +167,30 @@ XlaOp Erf(XlaOp x) { // } // return p*x XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; - auto one = ScalarLike(x, 1.0); - auto w = -Log((one - x) * (one + x)); - - auto lt = Lt(w, ScalarLike(x, 5.0)); - auto coefficient = [&](int i) { - return Select(lt, - Broadcast(ScalarLike(x, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = coefficient(i) + p * w; - } - return p * x; - }); + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, FullLike(x, w_less_than_5_constants[i]), + FullLike(x, w_greater_than_5_constants[i])); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; } namespace { @@ -170,49 +217,94 @@ static constexpr std::array kLanczosCoefficients = { // t(z) = z + kLanczosGamma + 1/2 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) XlaOp Lgamma(XlaOp input) { - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs. (We could extend lgamma to accept complex + // types, but it doesn't seem necessary at this point.) + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input)); + if (!ShapeUtil::ElementIsFloating(shape)) { + return InvalidArgument( + "lgamma only accepts real floating-point arrays or scalars, but got " + "%s", + shape.ToString()); + } - XlaOp pi = ScalarLike(input, M_PI); - XlaOp log_pi = ScalarLike(input, std::log(M_PI)); - XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = + ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); - // If the input is less than 0.5 use Gauss's reflection formula: - // gamma(x) = pi / sin(pi * x) * gamma(1 - x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - XlaOp x = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - x = x + lanczos_coefficient / (z + index + one); - } + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. - // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - - // If z = a + 0j, the analytic continuation of log reduces to taking the - // absolute value of the real part. - // Re(log(z)) = Re(log|z| + arg(z)j) - // = log|a| - XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; - XlaOp result = Select(need_to_reflect, reflection, log_y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))) = sin(pi * frac(x)), where frac(x) is the + // fractional part of x. This is more numerically accurate: It doesn't + // overflow to inf like pi * x can, and if x is an integer, it evaluates to + // 0 exactly, which is significant because we then take the log of this + // value, and log(0) is inf. + // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + XlaOp abs_input = Abs(input); + XlaOp reflection_denom = Log(Sin(pi * (abs_input - Floor(abs_input)))); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + XlaOp reflection = + Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y, + -reflection_denom); + XlaOp result = Select(need_to_reflect, reflection, log_y); + + // lgamma(+/-inf) = +inf. + XlaOp inf_bcast = FullLike(input, std::numeric_limits::infinity()); + return Select(Or(IsFinite(input), // is finite, or + Not(Or(Lt(input, one), Ge(input, one)))), // is nan + result, inf_bcast); + }); } // Compute the Digamma function using Lanczos' approximation from "A Precision @@ -223,69 +315,96 @@ XlaOp Lgamma(XlaOp input) { // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) XlaOp Digamma(XlaOp input) { - XlaOp zero = ScalarLike(input, 0); - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - - XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp num = zero; - XlaOp denom = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); - denom = denom + lanczos_coefficient / (z + index + one); - } + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs. (We could extend digamma to accept + // complex types, but it doesn't seem necessary at this point.) + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input)); + if (!ShapeUtil::ElementIsFloating(shape)) { + return InvalidArgument( + "digamma only accepts real floating-point arrays or scalars, but got " + "%s", + shape.ToString()); + } + + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. - // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp y = log_t + num / denom - lanczos_gamma / t; - XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); - XlaOp result = Select(need_to_reflect, reflection, y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); + return Select(need_to_reflect, reflection, y); + }); } // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = ScalarLike(x, 0.5); - auto one = ScalarLike(x, 1.0); - auto two = ScalarLike(x, 2.0); - - auto round_val = Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * Floor(half * x); - auto is_odd = Eq(nearest_even_int, one); - return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs (What does it even mean to round a complex + // number? Do you round each component equally? In that case, you should + // just ask for that explicitly.) + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); + if (ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "RoundToEven doesn't accept complex inputs, but got %s", + shape.ToString()); + } + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); + + auto round_val = Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); + }); } // Trigonometric functions. -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// pi if x == -1 XlaOp Acos(XlaOp x) { - return ScalarLike(x, 2.0) * - Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) @@ -323,9 +442,88 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == C64 && conjugate; + auto perform_conj = + primitive_util::IsComplexType(shape.element_type()) && conjugate; return perform_conj ? Conj(x) : x; }); } +XlaOp NextAfter(XlaOp from, XlaOp to) { + auto builder = from.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from)); + int bitwidth = primitive_util::BitWidth(shape.element_type()); + auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth); + auto from_as_int = BitcastConvertType(from, int_type); + auto to_as_int = BitcastConvertType(to, int_type); + + // The result is NaN if either "from" or "to" are NaN. + auto from_is_nan = Ne(from, from); + auto to_is_nan = Ne(to, to); + auto nan_input = Or(from_is_nan, to_is_nan); + auto result_for_nan = + Broadcast(ScalarLike(from, std::numeric_limits::quiet_NaN()), + shape.dimensions()); + result_for_nan = BitcastConvertType(result_for_nan, int_type); + + // The sign bit is the MSB. + const int64 sign_mask = int64{1} << (bitwidth - 1); + // Discard the sign bit to make the result non-negative. + auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask)); + auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask)); + + // When both "from" and "to" are equal, the result is "to". + // N.B. It would not make a difference if we chose the result to be "from". + auto from_and_to_are_equal = Eq(from_as_int, to_as_int); + auto result_for_equal = to_as_int; + + // When both "from" and "to" are both 0, the result is "to". This ensures we + // get a zero signed like "to". + auto from_is_zero = Eq(from_abs, ZerosLike(from_abs)); + auto to_is_zero = Eq(to_abs, ZerosLike(to_abs)); + auto result_for_both_zero = to_as_int; + + auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask)); + auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask)); + + // If from == 0 && to != 0, we need to return the smallest subnormal number + // signed like "to". + auto result_for_from_zero_to_non_zero = + Or(to_sign, ScalarLike(from_as_int, 1)); + + // If the sign of "from" and "to" disagree: + // - we need to make the magnitude of "from" smaller so that it is closer to + // zero. + // + // Otherwise the signs agree: + // - "from" with a magnitude larger than "to" means we need to make the + // magnitude smaller. + // - "from" with a magnitude smaller than "to" means we need to make the + // magnitude larger. + // - "from" with the same magnitude and sign as "to" has already been + // handled. + auto signs_disagree = Ne(from_sign, to_sign); + auto from_magnitude_larger_than_to = Gt(from_abs, to_abs); + auto result_has_smaller_magnitude = + Or(from_magnitude_larger_than_to, signs_disagree); + auto magnitude_adjustment = + Select(result_has_smaller_magnitude, + Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()), + Broadcast(ScalarLike(from_as_int, 1), shape.dimensions())); + auto result = Add(from_as_int, magnitude_adjustment); + // Handle from == ±0. + result = Select(from_is_zero, + Select(to_is_zero, result_for_both_zero, + result_for_from_zero_to_non_zero), + result); + // Handle from == to. + result = Select(from_and_to_are_equal, result_for_equal, result); + // Handle isnan(from) || isnan(to). + result = Select(nan_input, result_for_nan, result); + + // Cast back to the original type. + return BitcastConvertType(result, shape.element_type()); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 17612bf9fdc0f1eabb338671c93c025c5b268872..583481c7f329fec9b7c5262e820b6796654cb7a2 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -32,7 +32,7 @@ XlaOp Square(XlaOp operand); // Computes the reciprocal of 'operand'. XlaOp Reciprocal(XlaOp operand); -// Evaluates a polynomial given coefficients and `x`. +// Evaluates a polynomial given coefficients and 'x'. // N.B. Coefficients should be supplied in decreasing order. XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); @@ -86,10 +86,14 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); -// Applies a complex conjugation operation if `a` is complex and `conjugate` +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' // is true, otherwise returns its argument. xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fb13a73b4fdce1fd92a95030135c51e13e43653 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using Eigen::half; + +struct Testcase { + Testcase(string name, XlaOp (*op)(XlaOp), float (*host_op)(float)) + : name(name), op(op), host_op(host_op) {} + + Testcase& set_tolerance(float abs_err, float rel_err) { + error.abs = abs_err; + error.rel = rel_err; + return *this; + } + + Testcase& set_relaxed_nans() { + error.relaxed_nans = true; + return *this; + } + + Testcase& set_fewer_infs_ok() { + error.fewer_infs_ok = true; + return *this; + } + + Testcase& set_skip_pos_inf() { + skip_pos_inf = true; + return *this; + } + + Testcase& set_skip_neg_inf() { + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_infs() { + skip_pos_inf = true; + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_neg_zero() { + skip_neg_zero = true; + return *this; + } + + string name; + XlaOp (*op)(XlaOp); + float (*host_op)(float); + + ErrorSpec error{0.01}; + + // If true, don't test +/-infinity or negative 0. + bool skip_pos_inf = false; + bool skip_neg_inf = false; + bool skip_neg_zero = false; +}; + +void PrintTo(const Testcase& tc, std::ostream* os) { *os << tc.name; } + +class MathExhaustiveTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + MathExhaustiveTest() { + // Disable fast-math, otherwise we get the wrong results for e.g. + // sqrt(-inf). + SetFastMathDisabled(true); + } +}; + +// Checks a function's behavior on all fp16 values. +// +// TODO(jlebar): asin and lgamma tests fail on interpreter. +XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) { + const Testcase& tc = GetParam(); + XlaBuilder b(TestName()); + + std::vector input; + for (uint32 i = 0; i < 1 << 16; ++i) { + half h; + h.x = i; + + // If we're not using infinity as an input, use 0 as a placeholder rather + // than simply skipping this element. We do this because when the test + // framework reports an incorrect answer, it tells us which index failed. + // So long as our inputs are a simple list of all possible float16s, we can + // convert an index to a half with e.g. the following Python: + // + // np.frombuffer(array('H', [12345]), dtype=np.float16)[0] + // + // but as soon as our list of inputs has any gaps, this doesn't work. + if (std::isinf(static_cast(h)) && + ((tc.skip_pos_inf && h > half{0}) || + (tc.skip_neg_inf && h < half{0}))) { + h = half{0}; + } + + if (h == half{0} && tc.skip_neg_zero && + std::signbit(static_cast(h))) { + h = half{0}; + } + + input.push_back(h); + } + + std::vector expected_result; + for (const auto& h : input) { + expected_result.push_back( + static_cast(tc.host_op(static_cast(h)))); + } + + XlaOp param = AddParam(LiteralUtil::CreateR1(input), &b); + tc.op(param); + ComputeAndCompareR1(&b, expected_result, {}, tc.error); +} + +// TODO(b/123355973): The following tests from math.cc are missing. +// +// - Many failures. +// +// Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(), +// Testcase{"asinh", Asinh, std::asinh}, +// Testcase{"sinh", Sinh, std::sinh}, +// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(), +// Testcase{"erf", Erf, std::erf}, +// Testcase{"round_to_even", RoundToEven, +// [](float x) { return std::nearbyint(x / 2) * 2; }}, +// +// - No equivalent std function to compare with. +// +// Testcase{"erfinv", ErfInv, std::erfinv}, +// Testcase{"digamma", Digamma, std::digamma}, +// +// - Needs a special test (function takes two args, and simply computing in f32 +// and downcasting to f16 doesn't give the correct answer). +// +// Testcase{"nextafter", NextAfter, std::nextafter}, +// +// TODO(b/123355973): Test math functions not from math.cc (e.g. log). +// TODO(b/123355973): Test bf16 and f32. +// +INSTANTIATE_TEST_CASE_P( + MathExhaustiveTest_Instantiation, MathExhaustiveTest, + ::testing::ValuesIn(std::vector{ + Testcase{"sqrt", Sqrt, std::sqrt}.set_skip_neg_inf(), + Testcase{"rsqrt", Rsqrt, [](float x) { return 1 / std::sqrt(x); }} + .set_tolerance(0.05, 0.05) + .set_skip_infs() + .set_skip_neg_zero(), + Testcase{"square", Square, [](float x) { return x * x; }}, + Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }}, + Testcase{"erfc", Erfc, std::erfc}, + Testcase{"lgamma", Lgamma, std::lgamma} + .set_tolerance(0.1, 0.15) + .set_fewer_infs_ok(), + Testcase{"asin", Asin, std::asin}.set_skip_infs(), + Testcase{"acos", Acos, std::acos}.set_skip_infs(), + Testcase{"atan", Atan, std::atan}, + Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05), + })); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index ae2ea225d1aadd7b3a794eabeca866c498f34760..c2e1251fc2fa09956b9b60d4e3e13a5d0cb61d2b 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -30,6 +30,45 @@ class MathTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Write TYPED_TESTs within the class definition so that we don't have to litter +// "this->" everywhere. +template +class MathTypedTest : public MathTest { + public: + void TestLogEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}}), &b)); + ComputeAndCompareR1(&b, + {-std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}, + {}, error_spec_); + } + + void TestLog1pEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log1p(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}), &b)); + ComputeAndCompareR1( + &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {}, + error_spec_); + } +}; + +// TODO(b/123355973): Add bfloat16 to TestTypes once it's working. +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif + +TYPED_TEST_CASE(MathTypedTest, TestTypes); + +XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); } + XLA_TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); @@ -106,6 +145,28 @@ XLA_TEST_F(MathTest, Lgamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +// TODO(jlebar): Fails on interpreter due to unimplemented operation. +XLA_TEST_F(MathTest, DISABLED_ON_INTERPRETER(LgammaF16)) { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + + // These seemingly arbitrary inputs came from debugging the lgamma + // implementation against a test which tried all possible f16 values. + auto x = ConstantR1(&b, { + half(-7360.0), + half(-4066.0), + half(-5.9605e-08), + }); + Lgamma(x); + std::vector expected = { + std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + half(16.64), + }; + ComputeAndCompareR1(&b, expected, {}, ErrorSpec{0.1}); +} + XLA_TEST_F(MathTest, Digamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, @@ -148,5 +209,40 @@ XLA_TEST_F(MathTest, RoundToEven) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, ErfRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erf(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, ErfcRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erfc(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, LgammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Lgamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, DigammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Digamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + RoundToEven(x); + EXPECT_FALSE(b.Build().status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index ffd744d190885b8e3f4149a48a706498b3787618..a5aea96090c59c78d20cfc10a4bd6b312be592c1 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -15,24 +15,32 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); + auto a = Iota(builder, U32, m); + auto b = Iota(builder, U32, n); auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); return ConvertElementType(indicator, type); } @@ -41,7 +49,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); @@ -64,105 +72,251 @@ XlaOp GetMatrixDiagonal(XlaOp x) { }); } -XlaOp Triangle(XlaOp x, bool lower) { +XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0(builder, diagonal); XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + return Broadcast(indicator, major_dims); }); } +XlaOp Triangle(XlaOp x, bool lower) { + return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x)) + : Select(TriangleMask(x, -1), ZerosLike(x), x); +} + XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } -XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config) { + for (auto dim : output_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(y_config, dim)) { + if (absl::c_count(output_config, dim) > 1) { + return InvalidArgument("Einsum has repeated output dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has output dimension without corresponding input dimension."); + } + for (auto dim : x_config) { + if (absl::c_linear_search(y_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(x_config, dim) > 1) { + return InvalidArgument("Einsum has repeated lhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has lhs dimension without corresponding rhs or output " + "dimension."); + } + for (auto dim : y_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(y_config, dim) > 1) { + return InvalidArgument("Einsum has repeated rhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has rhs dimension without corresponding lhs or output " + "dimension."); + } + return Status::OK(); +} + +xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, + absl::Span output_config, + xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + TF_RETURN_IF_ERROR( + ValidateEinsumNumericDimensions(x_config, y_config, output_config)); + const int64 x_rank = x_config.size(); + const int64 y_rank = y_config.size(); + const int64 output_rank = output_config.size(); + absl::flat_hash_set x_map; + absl::flat_hash_set y_map; + absl::flat_hash_set output_map; - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { - return InvalidArgument( - "Arguments to BatchDot have different ranks: %s vs. %s", - ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + auto find = [&](const absl::flat_hash_set& map, int64 d) { + return map.count(d) != 0; + }; + + auto insert = [&](absl::flat_hash_set& map, char d) { + CHECK(!find(map, d)); + map.insert(d); + }; + + for (auto d : x_config) { + insert(x_map, d); } - const int ndims = ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return InvalidArgument( - "Arguments to BatchDot must have rank >= 2: got %d", ndims); + + for (auto d : y_config) { + insert(y_map, d); } - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return InvalidArgument( - "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", - i, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); + for (auto d : output_config) { + insert(output_map, d); } - int x_inner_dim = ndims - 1; - int y_inner_dim = ndims - 2; - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return InvalidArgument( - "Dimensions %d and %d of arguments to BatchDot must be equal: " - "shapes %s vs %s", - x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), - ShapeUtil::HumanString(y_shape)); + DotDimensionNumbers dnums; + std::vector lhs_outer_dims; + auto is_batch_dim = [&](int64 d) { + return find(x_map, d) && find(y_map, d) && find(output_map, d); + }; + auto is_contracting = [&](int64 d) { + return find(x_map, d) && find(y_map, d); + }; + auto rhs_dimension_number = [&](int64 d) { + return absl::c_find(y_config, d) - y_config.begin(); + }; + for (int64 i = 0; i < x_rank; ++i) { + auto dim_name = x_config[i]; + if (is_batch_dim(dim_name)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name)); + } else if (is_contracting(dim_name)) { + dnums.add_lhs_contracting_dimensions(i); + dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name)); + } else { + lhs_outer_dims.push_back(i); + } } - // Check for zero lhs/rhs dim size. - if (ShapeUtil::IsZeroElementArray(x_shape) || - ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + std::vector rhs_outer_dims; + for (int64 i = 0; i < y_rank; ++i) { + auto dim_name = y_config[i]; + if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) { + rhs_outer_dims.push_back(i); } - int x_outer_dim = ndims - 2; - int y_outer_dim = ndims - 1; - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return Broadcast( - ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), - dimensions); + } + + auto output_dimension_number = [&](char d) { + return absl::c_find(output_config, d) - output_config.begin(); + }; + + std::vector output_dims; + output_dims.reserve(output_rank); + for (auto d : dnums.lhs_batch_dimensions()) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : lhs_outer_dims) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : rhs_outer_dims) { + output_dims.push_back(output_dimension_number(y_config[d])); + } + + std::vector transpose_dims(output_rank); + for (int64 i = 0; i < output_rank; ++i) { + transpose_dims[output_dims[i]] = i; } PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); + return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims); + }); +} + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + const int ndims = x_shape.rank(); + batch_dimension_numbers.reserve(ndims - 2); + for (int i = 0; i < ndims - 2; ++i) { + batch_dimension_numbers.push_back(i); + } + std::vector x_config = batch_dimension_numbers; + x_config.push_back(ndims - 2); + x_config.push_back(ndims); + std::vector y_config = batch_dimension_numbers; + y_config.push_back(ndims); + y_config.push_back(ndims - 1); + std::vector output_config = batch_dimension_numbers; + output_config.push_back(ndims - 2); + output_config.push_back(ndims - 1); + return Einsum(x, x_config, y, y_config, output_config, precision); + }); +} - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config) { + std::array, 3> einsum_config_numeric; + std::vector main_split = + absl::StrSplit(einsum_config, ','); + + if (main_split.size() != 2) { + return InvalidArgument("Expected one \",\" in einsum_config."); + } + + auto maybe_invalid_character = [](char d) { + if (absl::ascii_isalpha(d)) { + return Status::OK(); } + if (d == '.') { + return InvalidArgument("Unsupported \"...\" or \".\" in einsum config."); + } + return InvalidArgument("Unexpected character in einsum config."); + }; - return DotGeneral(x, y, dot_dnums, &precision_proto); + auto& x_config = einsum_config_numeric[0]; + x_config.reserve(main_split[0].size()); + for (auto d : main_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + x_config.push_back(static_cast(d)); + } + std::vector y_output_split = + absl::StrSplit(main_split[1], "->"); + if (y_output_split.size() != 2) { + return InvalidArgument("Expected one \"->\" in einsum_config."); + } + auto& y_config = einsum_config_numeric[1]; + y_config.reserve(y_output_split[0].size()); + for (auto d : y_output_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + y_config.push_back(static_cast(d)); + } + auto& output_config = einsum_config_numeric[2]; + output_config.reserve(y_output_split[1].size()); + for (auto d : y_output_split[1]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + output_config.push_back(static_cast(d)); + } + return einsum_config_numeric; +} + +XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto einsum_config_numeric, + ParseEinsumString(einsum_config)); + return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], + einsum_config_numeric[2], precision); }); } @@ -170,7 +324,7 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 8856f99c7a0fee8f315aac11fab392cf5536f57b..491f1eab4cbffbbf9df70d4c35a61351df3e98aa 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -16,7 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,6 +35,10 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); // diagonal elements (i.e., with indices [..., i, i]). XlaOp GetMatrixDiagonal(XlaOp x); +// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal +// and false above that diagonal. +XlaOp TriangleMask(XlaOp x, int diagonal); + // Get the upper or lower triangle part of the last two dimensions XlaOp Triangle(XlaOp x, bool lower); @@ -61,6 +69,40 @@ xla::XlaOp BatchDot( xla::XlaOp x, xla::XlaOp y, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{0, 1},{2, 1},{0, 2}} +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config); + +// Determine if each dimension label is in at least two inputs. +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensins. So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 0593a7517ac125ca8dc5395cee76f6bc23232cd3..79cf529ee94b044ee0af788522200cd28c778997 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -15,8 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -101,5 +104,78 @@ XLA_TEST_F(MatrixTest, RowBatchDot) { ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); } + +XLA_TEST_F(MatrixTest, Einsum) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + Einsum(l_index, row, "abc,adc->abd"); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +XLA_TEST_F(MatrixTest, ParseEinsumString) { + auto to_vec = [](absl::string_view s) { + std::vector v; + v.reserve(s.size()); + for (auto c : s) { + v.push_back(int64{c}); + } + return v; + }; + + auto to_string = [&](absl::string_view x, absl::string_view y, + absl::string_view o) { + return absl::StrCat(x, ",", y, "->", o); + }; + + std::vector> good_test_cases = {{"ab", "bc", "ac"}, + {"Bab", "Bbc", "Bac"}, + {"ab", "cd", "dcba"}, + {"abc", "abd", "cbd"}}; + for (auto test_case : good_test_cases) { + auto parse_result_or_status = + ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2])); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(parse_result[i], to_vec(test_case[i])); + } + EXPECT_TRUE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } + + std::vector einsum_strings_that_fail_parsing = { + "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"}; + for (auto test_case : einsum_strings_that_fail_parsing) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_FALSE(parse_result_or_status.status().ok()); + } + + std::vector einsum_strings_that_fail_numeric_validation = { + "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"}; + for (auto test_case : einsum_strings_that_fail_numeric_validation) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + EXPECT_FALSE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index 72ca653173b78d9338f632c41779f2a30db1e978..640412ec8bcffd2565b11ba25b87f6bf6438d848 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -154,7 +154,7 @@ struct QRBlockResult { StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = ShapeUtil::Rank(a_shape); + const int num_dims = a_shape.rank(); if (num_dims < 2) { return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", a_shape.ToString()); @@ -325,7 +325,7 @@ StatusOr QRDecomposition( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = ShapeUtil::Rank(a_shape); + const int num_dims = a_shape.rank(); if (num_dims < 2) { return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", a_shape.ToString()); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index f8c7df3ff5189c817202eaf39adb572f7e232ec2..77145ba7d4c72435450d3e33d57b2507eb84d2fc 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace xla { @@ -26,7 +27,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_minor_dims <= n_dims); auto major_dims = AsInt64Slice(shape.dimensions()) .subspan( @@ -51,17 +52,17 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(start.size() == n_dims); + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return DynamicUpdateSlice(x, update, start_constant); + std::vector start_ops(start.size()); + for (int i = 0; i < start.size(); ++i) { + start_ops[i] = ConstantR0(builder, start_as_int32[i]); + } + return DynamicUpdateSlice(x, update, start_ops); }); } @@ -70,7 +71,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); const int64 n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -90,18 +91,17 @@ std::vector ConcatVectors(absl::Span xs, return output; } -XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { +StatusOr> PrependZerosInMajorDims( + XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - auto zero = Reshape(ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); - } - return ConcatInDim(builder, padded_starts, 0); - }); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + auto zero = ConstantR0(builder, 0); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = starts[i]; + } + return padded_starts; } } // namespace @@ -111,7 +111,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); @@ -119,7 +119,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, .subspan( /*pos=*/0, /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); auto padded_sizes = ConcatVectors(major_dims, sizes); return DynamicSlice(x, padded_starts, padded_sizes); }); @@ -127,8 +127,11 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return DynamicUpdateSlice(x, update, padded_starts); + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); + return DynamicUpdateSlice(x, update, padded_starts); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 27ff36c7491ab8397d46f3a49493ff2b904deb2d..0fbd138aca1e86f219d0459086fc09d20844f135 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { auto x = ConstantR1(&builder, inputs); xla::GetTupleElement(xla::TopK(x, kSize), 0); - std::sort(inputs.begin(), inputs.end(), std::greater()); + absl::c_sort(inputs, std::greater()); ComputeAndCompareR1(&builder, inputs, {}); } diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 5db9d10dff4c50d71cde934b3f3c345bee571f29..9f520bcdadfabc8ca9f9ee82b20804fd2c50d1db 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -34,7 +34,7 @@ namespace { // specified shape. In case of a (nested) tuple shape this is the total byte // size of all sub-shapes within the tuple. int64 DataSizeOfShape(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return ShapeUtil::ByteSizeOf(shape); } @@ -47,7 +47,7 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return Broadcast( ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc index 159e0c82dc4ff123533b65baac99388591c400d7..ba7fde118fde990fbb4aa9a34dd0f0e67ff5a93b 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -38,7 +38,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); - int ndims = ShapeUtil::Rank(shape); + int ndims = shape.rank(); int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; @@ -140,9 +140,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); - auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); + auto ones = FullLike(diags, 1); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -165,10 +163,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, // The first or last diagonal element should be set to 1 instead of -1 // though, since we never update it auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = (lower) ? 0 : block_size - 1; - auto output_block = DynamicUpdateSlice( - neg_identity, pos_one, - /*start_indices=*/ConstantR1(builder, 2, start_index)); + auto start_index = ConstantR0(builder, (lower) ? 0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); // Broadcast diag([1, -1, -1, ...]) to every block XlaOp output = Broadcast(output_block, @@ -211,12 +209,10 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR0(bodyb.get(), 0); auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto start_indices = - ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = - DynamicSlice(body_input, start_indices, + DynamicSlice(body_input, {zero, j, zero}, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} @@ -230,7 +226,7 @@ XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - body_out = DynamicUpdateSlice(body_out, update, start_indices); + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); auto next_i = i + ScalarLike(i, 1); Tuple(bodyb.get(), {next_i, body_out, body_input}); @@ -262,7 +258,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - int64 ndims = ShapeUtil::Rank(a_shape); + int64 ndims = a_shape.rank(); int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? -1 : -2; @@ -356,13 +352,13 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); - if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + if (a_shape.rank() != b_shape.rank()) { return InvalidArgument( "Arguments to TriangularSolve have shapes with different ranks: " "%s vs. %s", ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = ShapeUtil::Rank(a_shape); + const int64 ndims = a_shape.rank(); if (ndims < 2) { return InvalidArgument( "Arguments to TriangularSolve was rank %d but must have rank >= 2.", @@ -417,6 +413,11 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a, precision); + // Mask off the ignored elements of the triangular matrix a. + // TODO(phawkins): it would probably be preferable to perform this masking + // block by block inside SolveWithInvertedDiagonalBlocks. + a = Triangle(a, lower); + // We now find the solution using GEMMs auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index 3fea627e6a8c30b6f06fa61751aad386ec543843..284a2e9d183a6a7923fb59ac134ce3b3a3a96e35 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -37,12 +38,20 @@ namespace { using TriangularSolveTest = ClientLibraryTestBase; using TriangularSolveLeftLookingTest = ClientLibraryTestBase; +static constexpr float kNan = std::numeric_limits::quiet_NaN(); + Array2D AValsLower() { - return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; + return {{2, kNan, kNan, kNan}, + {3, 6, kNan, kNan}, + {4, 7, 9, kNan}, + {5, 8, 10, 11}}; } Array2D AValsUpper() { - return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; + return {{2, 3, 4, 5}, + {kNan, 6, 7, 8}, + {kNan, kNan, 9, 10}, + {kNan, kNan, kNan, 11}}; } Array2D BValsRight() { @@ -53,18 +62,20 @@ Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } +static constexpr complex64 kNanC64 = complex64(kNan, kNan); + Array2D AValsLowerComplex() { - return {{2, 0, 0, 0}, - {complex64(3, 1), 6, 0, 0}, - {4, complex64(7, 2), 9, 0}, + return {{2, kNanC64, kNanC64, kNanC64}, + {complex64(3, 1), 6, kNanC64, kNanC64}, + {4, complex64(7, 2), 9, kNanC64}, {5, 8, complex64(10, 3), 11}}; } Array2D AValsUpperComplex() { return {{2, 3, complex64(4, 3), 5}, - {0, 6, complex64(7, 2), 8}, - {0, 0, complex64(9, 1), 10}, - {0, 0, 0, 11}}; + {kNanC64, 6, complex64(7, 2), 8}, + {kNanC64, kNanC64, complex64(9, 1), 10}, + {kNanC64, kNanC64, kNanC64, 11}}; } Array2D BValsRightComplex() { @@ -367,5 +378,70 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { ErrorSpec(1e-2, 1e-2)); } +struct TriangularSolveTestSpec { + int m, n; // A is mxm, B is mxn + bool left_side; + bool lower; + bool transpose_a; +}; + +class TriangularSolveParametricTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(TriangularSolveParametricTest, Random) { + TriangularSolveTestSpec spec = GetParam(); + + XlaBuilder builder(TestName()); + + Array2D avals(spec.m, spec.m); + avals.FillRandom(1.0); + for (int i = 0; i < spec.m; ++i) { + avals(i, i) += 10; + } + + std::pair bdims = spec.left_side ? std::make_pair(spec.m, spec.n) + : std::make_pair(spec.n, spec.m); + Array2D bvals(bdims.first, bdims.second); + bvals.FillRandom(1.0); + + XlaOp a, b; + auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a, + /*conjugate_a=*/false, + /*block_size=*/3); + auto a_tri = Triangle(a, spec.lower); + a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a); + if (spec.left_side) { + BatchDot(a_tri, x); + } else { + BatchDot(x, a_tri); + } + + ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +std::vector TriangularSolveTests() { + std::vector specs; + for (int m : {5, 10}) { + for (int n : {5, 10}) { + for (bool left_side : {false, true}) { + for (bool lower : {false, true}) { + for (bool transpose_a : {false, true}) { + specs.push_back({m, n, left_side, lower, transpose_a}); + } + } + } + } + } + return specs; +} + +INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation, + TriangularSolveParametricTest, + ::testing::ValuesIn(TriangularSolveTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 049cd15738a619294b19d5cf74ca514d7b4a00ad..48b5f94538f453785194bc434a91ee0a10c020c2 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -164,9 +164,8 @@ StatusOr LocalExecutable::Run( // ExecutableRunOptions.eigen_intra_op_thread_pool. // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). - ServiceExecutableRunOptions service_options( - run_options, backend_->StreamBorrower(), - backend_->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions service_options(run_options, + backend_->StreamBorrower()); if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index fb9ea6ec3fc41d5e04ca125798a8199350470a44..b9bff06cbdbc3525eb19d5df885952c3971d9d6a 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + CHECK_EQ(tile_shape.rank(), 1); std::vector dimensions(1, num_tiles); *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 622fc158e11161b5b1167ccb432f51775767e3a1..5c9f9f708883f458b67205058fc7c1e1e2ad02f5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -192,9 +195,9 @@ StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { } void XlaBuilder::IsConstantVisitor(const int64 op_handle, - std::set* visited, + absl::flat_hash_set* visited, bool* is_constant) const { - if (visited->count(op_handle) != 0 || !*is_constant) { + if (visited->contains(op_handle) || !*is_constant) { return; } @@ -208,11 +211,21 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, } // TODO(b/32495713): We aren't checking the called computations. break; + case HloOpcode::kGetDimensionSize: { + int64 dimension_number = instr.dimensions(0); + const HloInstructionProto& operand = + *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie()); + Shape operand_shape(operand.shape()); + if (operand_shape.is_dynamic_dimension(dimension_number)) { + *is_constant = false; + } + break; + } // Non functional ops. case HloOpcode::kRng: case HloOpcode::kAllReduce: - // TODO(b/33009255): Implmement constant folding for cross replica sum. + // TODO(b/33009255): Implement constant folding for cross replica sum. case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCall: @@ -244,6 +257,29 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num) { + bool param_exists = false; + for (HloInstructionProto& instr : instructions_) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == target_param_num) { + param_exists = true; + Shape param_shape(instr.shape()); + Shape* param_shape_ptr = ¶m_shape; + for (int64 index : target_param_index) { + param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); + } + param_shape_ptr->set_dynamic_dimension(target_dim_num, + /*is_dynamic=*/true); + *instr.mutable_shape() = param_shape.ToProto(); + } + } + + if (!param_exists) { + return InvalidArgument( + "Asked to mark parameter %lld as dynamic sized parameter, but the " + "doesn't exists", + target_param_num); + } + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, dynamic_size_param_index}, @@ -263,29 +299,52 @@ XlaComputation XlaBuilder::BuildAndNoteError() { return build_status.ConsumeValueOrDie(); } -StatusOr XlaBuilder::Build() { +StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } - return Build(instructions_.back().id()); + return Build(instructions_.back().id(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(XlaOp root) { +StatusOr XlaBuilder::Build(XlaOp root, + bool remove_dynamic_dimensions) { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } - return Build(root.handle()); + return Build(root.handle(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(int64 root_id) { +StatusOr XlaBuilder::Build(int64 root_id, + bool remove_dynamic_dimensions) { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove + // all dynamic dimensions before building xla program until we have support in + // the backend. + if (remove_dynamic_dimensions) { + std::function remove_dynamic_dimension = + [&](ShapeProto* shape) { + if (shape->tuple_shapes_size() != 0) { + for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) { + remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); + } + } + for (int64 i = 0; i < shape->dimensions_size(); ++i) { + shape->set_is_dynamic_dimension(i, false); + } + }; + + for (auto& instruction : instructions_) { + remove_dynamic_dimension(instruction.mutable_shape()); + } + } + HloComputationProto entry; SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); @@ -310,7 +369,10 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); - + if (!input_output_aliases_.empty()) { + TF_RETURN_IF_ERROR( + PopulateInputOutputAlias(module, program_shape, input_output_aliases_)); + } *(module->mutable_dynamic_parameter_binding()) = dynamic_parameter_binding_.ToProto(); @@ -323,6 +385,35 @@ StatusOr XlaBuilder::Build(int64 root_id) { return std::move(computation); } +/* static */ Status XlaBuilder::PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases) { + HloInputOutputAliasConfig config(program_shape.result()); + for (auto& alias : input_output_aliases) { + // The HloInputOutputAliasConfig does not do parameter validation as it only + // carries the result shape. Maybe it should be constructed with a + // ProgramShape to allow full validation. We will still get an error when + // trying to compile the HLO module, but would be better to have validation + // at this stage. + if (alias.param_number >= program_shape.parameters_size()) { + return InvalidArgument("Invalid parameter number %ld (total %ld)", + alias.param_number, + program_shape.parameters_size()); + } + const Shape& parameter_shape = program_shape.parameters(alias.param_number); + if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) { + return InvalidArgument("Invalid parameter %ld index: %s", + alias.param_number, + alias.param_index.ToString().c_str()); + } + TF_RETURN_IF_ERROR(config.SetUpAlias( + alias.output_index, alias.param_number, alias.param_index, + HloInputOutputAliasConfig::AliasKind::kUserAlias)); + } + *module->mutable_input_output_alias() = config.ToProto(); + return Status::OK(); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, absl::Span broadcast_dimensions) { @@ -343,7 +434,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + operand_shape.rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); @@ -355,7 +446,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + for (int i = 0; i < operand_shape.rank(); i++) { if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape.dimensions(i)); @@ -398,8 +489,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); - const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); - const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + const int64 lhs_rank = lhs_shape.rank(); + const int64 rhs_rank = rhs_shape.rank(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; @@ -410,17 +501,19 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : shape.dimensions()) { - to_size.push_back(size); + std::vector to_size_is_dynamic; + for (int i = 0; i < shape.rank(); i++) { + to_size.push_back(shape.dimensions(i)); + to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i)); } - for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); - from_dim++) { + for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) { int64 to_dim = broadcast_dimensions[from_dim]; to_size[to_dim] = from_shape.dimensions(from_dim); + to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); } - const Shape& broadcasted_shape = - ShapeUtil::MakeShape(from_shape.element_type(), to_size); + const Shape& broadcasted_shape = ShapeUtil::MakeShape( + from_shape.element_type(), to_size, to_size_is_dynamic); TF_ASSIGN_OR_RETURN( XlaOp broadcasted_operand, InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); @@ -458,18 +551,18 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!ShapeUtil::IsTuple(shape)) { - if (!ShapeUtil::IsTuple(lhs_shape) && + if (!shape.IsTuple()) { + if (!lhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } - if (!ShapeUtil::IsTuple(rhs_shape) && + if (!rhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } - if (!ShapeUtil::IsTuple(ehs_shape) && + if (!ehs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); @@ -563,10 +656,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, // output, so to append dimensions on the left the instruction's dimensions // should just be the n highest dimension numbers of the output shape where // n is the number of input dimensions. - const int64 operand_rank = ShapeUtil::Rank(operand_shape); + const int64 operand_rank = operand_shape.rank(); std::vector dimensions(operand_rank); for (int i = 0; i < operand_rank; ++i) { - dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + dimensions[i] = i + shape.rank() - operand_rank; } return InDimBroadcast(shape, operand, dimensions); }); @@ -579,8 +672,17 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. - const auto& output_shape = + auto output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + for (int i = 0; i < broadcast_dimensions.size(); i++) { + if (broadcast_dimensions[i] < 0 || + broadcast_dimensions[i] > out_dim_size.size()) { + return InvalidArgument("Broadcast dimension %lld is out of bound", + broadcast_dimensions[i]); + } + output_shape.set_dynamic_dimension(broadcast_dimensions[i], + operand_shape.is_dynamic_dimension(i)); + } TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( operand_shape, output_shape, broadcast_dimensions) @@ -639,10 +741,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector starts(shape.rank(), 0); std::vector limits(shape.dimensions().begin(), shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); + std::vector strides(shape.rank(), 1); starts[dimno] = start_index; limits[dimno] = limit_index; strides[dimno] = stride; @@ -660,7 +762,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, GetShape(start_indices)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( - operand_shape, start_indices_shape, slice_sizes)); + operand_shape, {start_indices_shape}, slice_sizes)); *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { @@ -672,6 +774,34 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, }); } +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shapes, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + }); +} + XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -681,13 +811,38 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, {start_indices_shape})); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicUpdateSliceShape( - operand_shape, update_shape, start_indices_shape)); + operand_shape, update_shape, start_indices_shapes)); *instr.mutable_shape() = shape.ToProto(); + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); + operands); }); } @@ -780,7 +935,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + for (int i = 0; i < original_shape.rank(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape.dimensions(i)); } else { @@ -808,10 +963,9 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); - TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == - ShapeUtil::IsTuple(false_shape)); - HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect - : HloOpcode::kSelect; + TF_RET_CHECK(true_shape.IsTuple() == false_shape.IsTuple()); + HloOpcode opcode = + true_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect; return TernaryOp(opcode, pred, on_true, on_false); }); } @@ -835,7 +989,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", ShapeUtil::HumanString(tuple_shape)); @@ -915,13 +1069,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + if (lhs_shape.rank() != rhs_shape.rank()) { return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } - int num_dims = ShapeUtil::Rank(lhs_shape); + int num_dims = lhs_shape.rank(); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " @@ -1150,7 +1304,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1226,7 +1380,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1339,7 +1493,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { for (int i = 0; i < tokens.size(); ++i) { const XlaOp& operand = tokens[i]; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - if (!ShapeUtil::IsToken(operand_shape)) { + if (!operand_shape.IsToken()) { return InvalidArgument( "All operands to AfterAll must be tokens; operand %d has shape %s", i, ShapeUtil::HumanString(operand_shape)); @@ -1582,7 +1736,7 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - dimension = ShapeUtil::Rank(keys_shape) - 1; + dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); std::vector operands{keys}; @@ -1652,12 +1806,12 @@ XlaOp XlaBuilder::Map(absl::Span operands, *instr.mutable_shape() = shape.ToProto(); Shape output_shape(instr.shape()); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 output_rank = output_shape.rank(); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); for (XlaOp& new_operand : new_operands) { TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); - const int64 rank = ShapeUtil::Rank(shape); + const int64 rank = shape.rank(); if (rank != output_rank) { TF_ASSIGN_OR_RETURN(new_operand, InDimBroadcast(output_shape, new_operand, {})); @@ -1866,7 +2020,7 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::vector all_dimnos(operand_shape.rank()); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); return Reduce(operand, init_value, computation, all_dimnos); }); @@ -2292,7 +2446,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. - if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", ShapeUtil::HumanString(operand_shape)); } @@ -2332,7 +2486,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, } // TODO(b/111544877): Support tuple shapes. - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", ShapeUtil::HumanString(shape)); @@ -2385,7 +2539,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); bool is_constant = true; - std::set visited; + absl::flat_hash_set visited; IsConstantVisitor(operand.handle(), &visited, &is_constant); return is_constant; } @@ -2432,21 +2586,58 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.pop(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, LookUpInstructionByHandle(handle)); - for (int64 id : instr_proto->operand_ids()) { - if (related_ops.insert(id).second) { - worklist.push(id); + + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would have + // failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static bound + // of the shape. + int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension)); + auto constant_dimension_size = + static_cast(operand_proto->shape().dimensions(dimension)); + + Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + + HloInstructionProto const_instr; + *const_instr.mutable_shape() = literal.shape().ToProto(); + *const_instr.mutable_literal() = literal.ToProto(); + *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); + + const_instr.set_id(handle); + *const_instr.mutable_name() = + GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id()); + *entry.add_instructions() = + const_instr; // Add to the result constant graph. + } else { + for (int64 id : instr_proto->operand_ids()) { + if (related_ops.insert(id).second) { + worklist.push(id); + } + } + for (int64 called_id : instr_proto->called_computation_ids()) { + related_calls.insert(called_id); } - } - for (int64 called_id : instr_proto->called_computation_ids()) { - related_calls.insert(called_id); } } // Add related ops to the computation. for (int64 id : related_ops) { - auto* instr = entry.add_instructions(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, LookUpInstructionByHandle(id)); + + if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { + continue; + } + auto* instr = entry.add_instructions(); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. const string& new_name = @@ -2719,12 +2910,21 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 6e9b025e5d70c03e9f4c7e7fbc89976f314d48d7..3bd6d42363664721ee4c15c8dc4fc75a42d0591b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -197,11 +197,19 @@ class XlaBuilder { // status. Note that all ops that have been enqueued will be moved to the // computation being returned. The root of the computation will be the last // added operation. - StatusOr Build(); + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dyanmic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + StatusOr Build(bool remove_dynamic_dimensions = true); // Overload of Build which specifies a particular root instruction for the // computation. - StatusOr Build(XlaOp root); + StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = true); // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. @@ -269,6 +277,10 @@ class XlaBuilder { // and its real dynamic size is represented by `dynamic_param_index` in // parameter `dynamic_param_num`. // + // Note that this should be called before the dynamic parameters are used to + // create other operations, otherwise created operations won't have the + // dynamic dimensions information. + // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. Status SetDynamicBinding(int64 dynamic_size_param_num, @@ -276,9 +288,24 @@ class XlaBuilder { int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num); + // Adds a new input/output alias. Since the input/ouput shape information are + // not available until the computation is built, and eventual error in the + // arguments of this API will be detected only at computation Build() time. + void SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + input_output_aliases_.push_back({output_index, param_number, param_index}); + } + private: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + ShapeIndex output_index; + int64 param_number; + ShapeIndex param_index; + }; + // Build helper which takes the id of the root operation.. - StatusOr Build(int64 root_id); + StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); // Description for the methods below can be found in the corresponding public // functions section in this file. @@ -344,11 +371,18 @@ class XlaBuilder { XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); @@ -712,7 +746,8 @@ class XlaBuilder { // operation such as `RngNormal` or `Infeed`. The visitor walks the // computation starting at a given operation and sets is_constant to false iff // a parameter or stateful operation is encountered. - void IsConstantVisitor(const int64 op_handle, std::set* visited, + void IsConstantVisitor(const int64 op_handle, + absl::flat_hash_set* visited, bool* is_constant) const; // Checks bounds for convolution parameters. @@ -730,6 +765,12 @@ class XlaBuilder { int64 GetNextId() { return ++next_id_; } + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static Status PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases); + string name_; // Name to use for the built computation. // The next sequential ID for every instruction/computation contained within @@ -749,6 +790,9 @@ class XlaBuilder { // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -850,9 +894,14 @@ class XlaBuilder { friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + friend XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension); @@ -1294,10 +1343,15 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // The size of the slice in each dimension is passed in 'slice_sizes', // which specify the end point of exclusive slice intervals in each // dimension [start, start + size). -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); @@ -1313,10 +1367,15 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] // [7 8 9] [7 8 9 ] // -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo update dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index b3f5be300d3f15397ad33858a6a9cab5f6029688..098165000a29cb28cb0ef906dbdb1ff9ae2f24e8 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -39,7 +40,8 @@ using ::testing::HasSubstr; class XlaBuilderTest : public ::testing::Test { protected: StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(/*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -50,7 +52,8 @@ class XlaBuilderTest : public ::testing::Test { // Overload which explicitly specifies the root instruction. StatusOr> BuildHloModule(XlaBuilder* b, XlaOp root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(root, /*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -446,6 +449,417 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, DynamicParameter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, + /*dynamic_size_param_index=*/{}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); + const Shape& param_shape = module->entry_computation() + ->parameter_instruction(0) + ->shape() + .tuple_shapes(1); + EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicUnary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Neg(gte); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, + /*broadcast_dimensions=*/{1, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicPad) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pad_val = ConstantR0(&b, -1); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + PaddingConfig padding_config; + for (int i = 0; i < 2; i++) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + } + Pad(gte, pad_val, padding_config); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicConvolution) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}), + ShapeUtil::MakeShape(F32, {2, 2, 128, 8}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto input = GetTupleElement(p0, 0); + auto filter = GetTupleElement(p0, 1); + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {true, false, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicDot) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4}), + ShapeUtil::MakeShape(F32, {2, 4, 5}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + + auto lhs = GetTupleElement(p0, 0); + auto rhs = GetTupleElement(p0, 1); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + DotGeneral(lhs, rhs, dnums); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduce) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4, 3}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + Reduce(gte, init, sum, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduceWindow) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + XlaBuilder bge(TestName()); + Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build()); + + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto source = GetTupleElement(p0, 1); + SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source, + init, sum); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReshape) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/2)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/3)); + auto gte = GetTupleElement(p0, 0); // f32[2, 3, <=4, <=5, 6] + Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_TRUE(result_shape.is_dynamic_dimension(3)); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {false, true, false, true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelect) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/1)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Select(pred, gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_FALSE(result_shape.is_dynamic_dimension(2)); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto gte0 = GetTupleElement(p0, 0); // f32[4,<=5,6] + auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] + Select(pred, gte0, gte1); + Status status = BuildHloModule(&b).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Operands to select must be the same shape; " + "got f32[4,<=5,6] and f32[4,5,<=6]")); +} + +TEST_F(XlaBuilderTest, DynamicTranspose) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Transpose(gte, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true})) + << result_shape; +} + TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); @@ -455,5 +869,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } +TEST_F(XlaBuilderTest, CheckInputOutputAlias) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); + auto add = Add(p0, p1); + auto sub = Sub(p0, p1); + auto root = Tuple(&b, {add, sub}); + + b.SetUpAlias({1}, 0, {}); + b.SetUpAlias({0}, 1, {}); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + + const HloInputOutputAliasConfig& config = module->input_output_alias_config(); + EXPECT_TRUE(config.ParameterHasAlias(0, {})); + EXPECT_TRUE(config.ParameterHasAlias(1, {})); + + auto alias_p0 = config.GetAliasedOutput(0, {}); + ASSERT_TRUE(alias_p0.has_value()); + EXPECT_EQ(*alias_p0, ShapeIndex({1})); + + auto alias_p1 = config.GetAliasedOutput(1, {}); + ASSERT_TRUE(alias_p1.has_value()); + EXPECT_EQ(*alias_p1, ShapeIndex({0})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h index a1463aa15941b9c265db94e2eb3cc176fab6695b..4359f3b7deb8e585494cb2a9c7115eac6a312c8e 100644 --- a/tensorflow/compiler/xla/error_spec.h +++ b/tensorflow/compiler/xla/error_spec.h @@ -30,6 +30,19 @@ struct ErrorSpec { // In effect, this allows the tested operation to produce incorrect results // for inputs outside its mathematical domain. bool relaxed_nans; + + // If this is true, then we treat each +/-inf in the actual result as + // equivalent to our choice of either +/-inf or the min/max floating-point + // value. + // + // If the expected result is +/-inf, the actual result must still be +/-inf. + // + // In effect, this allows the tested operation to overflow, so long as it's + // overflowing on "large" values. + // + // (We could have a symmetric more_infs_ok flag if necessary; right now it + // appears not to be.) + bool fewer_infs_ok = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index ba3217f31b55bd1428f67da6154a46c8bc304053..6f36d11dfb34eb27e79ea4ff797d35f80fb44b27 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Pulls in the ::stream_executor -> ::xla::se namespace alias. -#include "tensorflow/compiler/xla/types.h" - // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -28,12 +25,6 @@ class Stream; class Platform; } // namespace stream_executor -namespace tensorflow { -namespace thread { -class ThreadPool; -} // namespace thread -} // namespace tensorflow - namespace Eigen { struct ThreadPoolDevice; } // namespace Eigen diff --git a/tensorflow/compiler/xla/g3doc/broadcasting.md b/tensorflow/compiler/xla/g3doc/broadcasting.md index 2870869a2cef13a9105b9dc9fa4d657834288f86..5c0525c1e9adf9f37d945170d05e7c18fa3d8852 100644 --- a/tensorflow/compiler/xla/g3doc/broadcasting.md +++ b/tensorflow/compiler/xla/g3doc/broadcasting.md @@ -168,7 +168,7 @@ consult the Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting using degenerate dimensions can both be performed in the same binary operation. -For example, a vector of size 4 and an matrix of size 1x2 can be added together +For example, a vector of size 4 and a matrix of size 1x2 can be added together using broadcast dimensions value of (0): |1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector. @@ -176,7 +176,7 @@ using broadcast dimensions value of (0): First the vector is broadcast up to rank 2 (matrix) using the broadcast dimensions. The single value (0) in the broadcast dimensions indicates that dimension zero of the vector matches to dimension zero of the matrix. This -produces an matrix of size 4xM where the value M is chosen to match the +produces a matrix of size 4xM where the value M is chosen to match the corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is produced: diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 9a9cd08c301502cbda8858225182d95fca4bf7ae..c5f9377f98868cdf6d5c711cf80ede5d41fd8305 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -636,11 +636,15 @@ details, see `tf.nn.depthwise_conv2d`. The `batch_group_count` (default value 1) argument can be used for depthwise filters during backpropagation. `batch_group_count` needs to be a divisor of the -size of the `lhs` batch dimension. If `batch_group_count` is greater than 1, it -means that conceptually the output batch dimension is split evenely in -`batch_group_count` groups, such that each group consists of a consecutive -subsequence of batches. Each output batch element is the reduced value of the -batch group size. +size of the `lhs` (input) batch dimension. If `batch_group_count` is greater +than 1, it means that the output batch dimension should be of size +`batch_group_size` where `batch_group_size = input batch / batch_group_count`. +For convolutions with `batch_group_count` greater than 1, the input batch size +must evenly divide into batch_group_size and output feature size, which implies +that the output feature size must be equal to batch_group_count. Conceptually, +this can be achieved by performing the usual convolution, and then scraping +`batch_group_size` number of elements on the diagonal of the matrix formed by +output batch and output feature. The output shape has these dimensions, in this order: @@ -871,9 +875,7 @@ DotGeneral performs the sum of products over contracting dimensions specified in 'dimension_numbers'. Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need -to be the same, but must be listed in the same order in both -'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. -There must be exactly one contracting dimension on both 'lhs' and 'rhs'. +to be the same and but must have the same dimension sizes. Example with contracting dimension numbers: @@ -892,10 +894,8 @@ DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, {15.0, 30.0} } ``` -Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same -dimension number, must be listed in the same order in both arrays, must -have the same dimension sizes, and must be ordered before contracting and -non-contracting/non-batch dimension numbers. +Associated batch dimension numbers from the 'lhs' and 'rhs' must +have the same dimension sizes. Example with batch dimension numbers (batch size 2, 2x2 matrices): @@ -944,21 +944,21 @@ dimension: [start, start + size). The shape of `start_indices` must be rank == `DynamicSlice(operand, start_indices, size_indices)` -| Arguments | Type | Semantics | -| --------------- | ------------------- | ----------------------------------- | -| `operand` | `XlaOp` | N dimensional array of type T | -| `start_indices` | `XlaOp` | Rank 1 array of N integers | -: : : containing the starting indices of : -: : : the slice for each dimension. Value : -: : : must be greater than or equal to : -: : : zero. : -| `size_indices` | `ArraySlice` | List of N integers containing the | -: : : slice size for each dimension. Each : -: : : value must be strictly greater than : -: : : zero, and start + size must be less : -: : : than or equal to the size of the : -: : : dimension to avoid wrapping modulo : -: : : dimension size. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : +| `size_indices` | `ArraySlice` | List of N integers containing the | +: : : slice size for each dimension. : +: : : Each value must be strictly : +: : : greater than zero, and start + : +: : : size must be less than or equal to : +: : : the size of the dimension to avoid : +: : : wrapping modulo dimension size. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1009,19 +1009,22 @@ the rank of `operand`. `DynamicUpdateSlice(operand, update, start_indices)` -| Arguments | Type | Semantics | -| --------------- | ------- | ------------------------------------------------ | -| `operand` | `XlaOp` | N dimensional array of type T | -| `update` | `XlaOp` | N dimensional array of type T containing the | -: : : slice update. Each dimension of update shape : -: : : must be strictly greater than zero, and start + : -: : : update must be less than or equal to the operand : -: : : size for each dimension to avoid generating : -: : : out-of-bounds update indices. : -| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the | -: : : starting indices of the slice for each : -: : : dimension. Value must be greater than or equal : -: : : to zero. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `update` | `XlaOp` | N dimensional array of type T | +: : : containing the slice update. Each : +: : : dimension of update shape must be : +: : : strictly greater than zero, and : +: : : start + update must be less than : +: : : or equal to the operand size for : +: : : each dimension to avoid generating : +: : : out-of-bounds update indices. : +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1095,7 +1098,7 @@ When `Op` is `Rem`, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. Integer division overflow (signed/unsigned division/remainder by zero or signed -divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined +division/remainder of `INT_SMIN` with `-1`) produces an implementation defined value. An alternative variant with different-rank broadcasting support exists for these diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2a0241af3ef359c4d1c6c1ab9319b5b293110f7a..7e22a32e545e4155545ffcfb9582187eadec3a82 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -141,7 +141,7 @@ namespace xla { /* static */ bool IndexUtil::IndexInBounds(const Shape& shape, absl::Span index) { - int64 rank = ShapeUtil::Rank(shape); + int64 rank = shape.rank(); if (rank != index.size()) { return false; } diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index ddccd8c798df5b926d2e5aea8975cb6cb6640824..2fe9b56c6bdffb931726f60ab75081361b43ebb4 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -101,13 +101,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + if (shape.IsOpaque() || shape.IsToken()) { // Opaque and token types have empty layouts. return Layout(); } // A Layout proto corresponds to a single array, not a tuple. - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -128,13 +128,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) { - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { // Tuple shape. for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsArray(*shape)) { + } else if (shape->IsArray()) { shape->mutable_layout()->set_format(DENSE); auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->resize(shape->dimensions_size(), 0); @@ -160,7 +160,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutInShape( const Shape& shape, bool allow_missing_layouts) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // Tuple shape. if (shape.has_layout()) { return InvalidArgument("tuple should not have a layout field"); @@ -170,7 +170,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { ValidateLayoutInShape(element_shape, allow_missing_layouts)); } return Status::OK(); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { if (!shape.has_layout()) { if (allow_missing_layouts) { return Status::OK(); @@ -192,11 +192,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { if (layout.minor_to_major_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", @@ -211,19 +211,19 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == DENSE) { - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (layout.minor_to_major_size() != shape.rank()) { return InvalidArgument( "layout minor_to_major field contains %d elements, " "but shape is rank %d: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), + layout.minor_to_major_size(), shape.rank(), absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + std::vector dimensions_in_layout(shape.rank(), false); + for (int64 i = 0; i < shape.rank(); ++i) { int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (dim < 0 || dim >= shape.rank()) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", HumanString(layout)); @@ -255,8 +255,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsDense(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsDense(shape.layout()); } /* static */ bool LayoutUtil::IsDense(const Layout& layout) { @@ -276,8 +275,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsSparse(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout()); } /* static */ bool LayoutUtil::IsSparse(const Layout& layout) { @@ -290,11 +288,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // Tuple shape: all subshapes must have a layout. - return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), - [](const Shape& s) { return HasLayout(s); }); - } else if (!ShapeUtil::IsArray(shape)) { + return absl::c_all_of(shape.tuple_shapes(), + [](const Shape& s) { return HasLayout(s); }); + } else if (!shape.IsArray()) { // Opaque, token types etc. ignore layout. return true; } @@ -360,11 +358,11 @@ namespace { // Internal helper for recursively copying layouts. Status CopyLayoutInternal(const Shape& src, Shape* dst) { - if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { + if (src.IsTuple() != dst->IsTuple()) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); } - if (ShapeUtil::IsTuple(src)) { + if (src.IsTuple()) { if (ShapeUtil::TupleElementCount(src) != ShapeUtil::TupleElementCount(*dst)) { return InvalidArgument( @@ -376,7 +374,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) { } } else { if (src.has_layout()) { - if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) { + if (src.rank() != dst->rank()) { return InvalidArgument("cannot copy layout from shape: ranks differs"); } TF_RETURN_IF_ERROR( @@ -398,9 +396,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs)) { - if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (lhs.IsTuple()) { + if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -409,8 +407,8 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else if (ShapeUtil::IsArray(lhs)) { - return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && + } else if (lhs.IsArray()) { + return lhs.rank() == rhs.rank() && LayoutUtil::Equal(lhs.layout(), rhs.layout()); } else { // Layouts of non-array and non-tuple shapes is ignored. @@ -426,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { positions_in_layout.push_back( PositionInContainer(layout.minor_to_major(), dim)); } - std::sort(positions_in_layout.begin(), positions_in_layout.end()); + absl::c_sort(positions_in_layout); for (size_t i = 1; i < positions_in_layout.size(); ++i) { if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { return false; diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 277c98721e59ac12965392500fdfdc3d91e59a8b..8600e8752cfbe072407391559d210d0b49bea511 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -29,10 +29,12 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -107,7 +109,7 @@ Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -118,7 +120,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -129,7 +131,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + new SparseIndexArray(max_sparse_elements, shape.rank())); } else { piece->set_buffer(new char[piece->size_bytes()]); } @@ -187,7 +189,7 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { Literal literal(shape); literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { + if (piece->subshape().IsArray()) { memset(piece->untyped_data(), 0, piece->size_bytes()); } }); @@ -208,16 +210,15 @@ template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); + TF_RET_CHECK(shape().rank() == dest_base.size()); auto linear_index = [](const Shape& shape, absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { + if (src_literal.shape().rank() == 0 || shape().rank() == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); @@ -312,7 +313,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, proto_element = &proto_element->tuple_literals(i); } - if (ShapeUtil::IsTuple(piece->subshape())) { + if (piece->subshape().IsTuple()) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( @@ -326,7 +327,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } - CHECK(ShapeUtil::IsArray(piece->subshape())); + CHECK(piece->subshape().IsArray()); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); return Status::OK(); @@ -336,7 +337,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); + CHECK(shape().IsTuple()); std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), @@ -375,7 +376,7 @@ void CopyElementsBetween(absl::Span dest, if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } - std::vector index(ShapeUtil::Rank(dest_shape)); + std::vector index(dest_shape.rank()); do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; @@ -392,7 +393,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { memcpy(buffer(), src.buffer(), src.size_bytes()); } else { TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); + std::vector origin(subshape().rank(), 0); switch (subshape().element_type()) { #define COPY_ELEMENTS(XLA_T, NATIVE_T) \ case (XLA_T): \ @@ -412,6 +413,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { COPY_ELEMENTS(F32, float); COPY_ELEMENTS(F64, double); COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(C128, complex128); COPY_ELEMENTS(PRED, bool); #undef COPY_ELEMENTS default: @@ -438,7 +440,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { + if (!piece->subshape().IsArray()) { return Status::OK(); } @@ -477,7 +479,7 @@ Status Literal::MoveFrom(Literal&& src_literal, src_literal.root_piece_->ForEachSubpiece( [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { + if (!src_piece.subshape().IsArray()) { return; } @@ -504,8 +506,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(src_literal.shape().IsArray()) << ShapeUtil::HumanString(src_literal.shape()); TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); @@ -549,6 +551,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, case C64: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); + case C128: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); @@ -562,8 +567,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, } void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(element_count(), values.bits()); CHECK_EQ(shape().element_type(), PRED); for (int64 i = 0; i < static_cast(values.bits()); ++i) { @@ -592,7 +597,7 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { ShapeUtil::ForEachSubshape( result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_CHECK_OK(result.CopyFrom(*this, /*dest_shape_index=*/index, /*src_shape_index=*/index)); @@ -603,7 +608,7 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Broadcast only supports arrays."); } @@ -643,13 +648,12 @@ StatusOr LiteralBase::Broadcast( StatusOr LiteralBase::Reshape( absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Reshape does not support tuples."); } Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank())); } else { output = Clone(); } @@ -671,8 +675,8 @@ StatusOr LiteralBase::Reshape( } Literal LiteralBase::Transpose(absl::Span permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + CHECK(shape().IsArray()) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, shape().rank())) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. @@ -711,10 +715,10 @@ template Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { Literal result_literal(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(result_shape.rank()); result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + for (int64 i = 0; i < result_shape.rank(); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); @@ -725,10 +729,10 @@ Literal LiteralBase::SliceInternal( Literal LiteralBase::Slice(absl::Span start_indices, absl::Span limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + CHECK(shape().IsArray()) << "tuple is not supported for slice"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + for (int64 dnum = 0; dnum < shape().rank(); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) << "dnum = " << dnum; @@ -768,6 +772,8 @@ Literal LiteralBase::Slice(absl::Span start_indices, return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); + case C128: + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -816,6 +822,10 @@ string LiteralBase::GetAsString(absl::Span multi_index, complex64 c = Get(multi_index, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); } @@ -870,6 +880,11 @@ string LiteralBase::GetSparseElementAsString( GetSparseElement(sparse_element_number, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << "Invalid element type for sparse arrays: " << PrimitiveType_Name(subshape.element_type()); @@ -906,7 +921,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsArray(subshape)) { + if (!subshape.IsArray()) { return; } @@ -998,6 +1013,9 @@ void LiteralBase::Piece::SortSparseElements() { case C64: SortSparseElementsInternal(); break; + case C128: + SortSparseElementsInternal(); + break; case F16: SortSparseElementsInternal(); break; @@ -1056,7 +1074,7 @@ void SparseArrayToStringHelper(const LiteralBase& literal, pieces->push_back(ShapeToString(print_layout, subshape)); } pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); int64 num_elements = literal.sparse_element_count(); for (int64 i = 0; i < num_elements; ++i) { if (i > 0) { @@ -1079,7 +1097,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_shape, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); std::function dimensions, std::vector*)> to_string_recursive = [&](absl::Span dimensions, @@ -1154,10 +1172,10 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { TupleToStringHelper(literal, shape_index, print_shape, print_layout, pieces); - } else if (ShapeUtil::IsToken(subshape)) { + } else if (subshape.IsToken()) { pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, @@ -1217,7 +1235,7 @@ namespace { template Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); + CHECK(src_literal.shape().IsArray()); Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); @@ -1232,7 +1250,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, } template -Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { +typename std::enable_if<(std::is_same::value) && + (std::is_same::value || + std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return NativeDestT(static_cast(src)); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(!std::is_same::value) || + (!std::is_same::value && + !std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1276,22 +1311,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } -template -Literal ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - Literal result_literal( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal.data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - template Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -1321,9 +1340,11 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, bitcast); CONVERT_IF_TYPES_MATCH(PRED) CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S16) CONVERT_IF_TYPES_MATCH(S32) CONVERT_IF_TYPES_MATCH(S64) CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U16) CONVERT_IF_TYPES_MATCH(U32) CONVERT_IF_TYPES_MATCH(U64) CONVERT_IF_TYPES_MATCH(F16) @@ -1332,10 +1353,15 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: - if (!bitcast) { - return ConvertToC64(src_literal); + if (bitcast) { + break; } - break; + return ConvertIfTypesMatch(src_literal, false); + case C128: + if (bitcast) { + break; + } + return ConvertIfTypesMatch(src_literal, false); // Other types are not yet supported. default: break; @@ -1348,7 +1374,7 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, StatusOr ConvertSwitch(const LiteralBase& literal, PrimitiveType primitive_dest_type, bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + TF_RET_CHECK(literal.shape().IsArray()); if (literal.shape().element_type() == primitive_dest_type) { return literal.Clone(); } @@ -1359,9 +1385,11 @@ StatusOr ConvertSwitch(const LiteralBase& literal, bitcast); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S16) CONVERT_IF_DEST_TYPE_MATCHES(S32) CONVERT_IF_DEST_TYPE_MATCHES(S64) CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U16) CONVERT_IF_DEST_TYPE_MATCHES(U32) CONVERT_IF_DEST_TYPE_MATCHES(U64) CONVERT_IF_DEST_TYPE_MATCHES(F16) @@ -1401,7 +1429,7 @@ StatusOr LiteralBase::BitcastConvert( } StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { - if (!ShapeUtil::IsTuple(dest_shape)) { + if (!dest_shape.IsTuple()) { return Convert(dest_shape.element_type()); } std::vector elements; @@ -1433,7 +1461,7 @@ StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { template bool LiteralBase::Piece::EqualElementsInternal( const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { + if (multi_index->size() == subshape().rank()) { return (Get(*multi_index) == other.Get(*multi_index)); } for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { @@ -1483,6 +1511,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case C64: return EqualElementsInternal(other, &multi_index); + case C128: + return EqualElementsInternal(other, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); @@ -1496,7 +1526,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1526,7 +1556,7 @@ static bool AllElementsEqualValue(absl::Span data, bool LiteralBase::IsAll(int8 value) const { return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1594,7 +1624,7 @@ bool LiteralBase::IsAll(int8 value) const { bool LiteralBase::IsAllFloat(float value) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1626,6 +1656,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const { case C64: return AllElementsEqualValue(root_piece().data(), value); + case C128: + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } @@ -1634,7 +1667,7 @@ bool LiteralBase::IsAllComplex(complex64 value) const { bool LiteralBase::IsAllFirst() const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1705,6 +1738,11 @@ bool LiteralBase::IsAllFirst() const { auto data = piece.data(); return AllElementsEqualValue(data, data[0]); } + + case C128: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } default: return false; } @@ -1718,11 +1756,11 @@ bool LiteralBase::IsAllFirst() const { } bool LiteralBase::IsR1Iota() const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return false; } - if (ShapeUtil::Rank(shape()) != 1) { + if (shape().rank() != 1) { return false; } @@ -1754,6 +1792,8 @@ bool LiteralBase::IsR1Iota() const { return Get({idx}) == static_cast(idx); case C64: return Get({idx}) == complex64(idx, 0.0f); + case C128: + return Get({idx}) == complex128(idx, 0.0f); case PRED: return Get({idx}) == idx; // token, opaque, tuple, etc. are all not iota. @@ -1773,7 +1813,7 @@ bool LiteralBase::IsR1Iota() const { } bool LiteralBase::IsZero(absl::Span indices) const { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); switch (shape().element_type()) { case U8: return Get(indices) == 0; @@ -1797,6 +1837,8 @@ bool LiteralBase::IsZero(absl::Span indices) const { return Get(indices) == 0.0; case C64: return Get(indices) == complex64(0.0f, 0.0f); + case C128: + return Get(indices) == complex128(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case BF16: @@ -1884,6 +1926,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { proto->add_c64s(value.imag()); } break; + case C128: + for (complex128 value : data()) { + proto->add_c128s(value.real()); + proto->add_c128s(value.imag()); + } + break; case TUPLE: case TOKEN: // Nothing to do but assign the shape which is done above. @@ -1896,12 +1944,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1932,14 +1980,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve // the necessary space in spare_indices. - TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) - << "Scalar shapes cannot be sparse"; - TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0) << "Unexpected number of indices in proto (" << proto.sparse_indices_size() << ") for shape of rank " - << ShapeUtil::Rank(subshape()); - const int64 index_count = - proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + << subshape().rank(); + const int64 index_count = proto.sparse_indices_size() / subshape().rank(); sparse_indices()->Resize(index_count); // Copy the indices from the proto into the SparseIndexArray object. @@ -2018,7 +2064,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { for (int64 i = 0; i < complex_data.size(); ++i) { complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } - } break; + break; + } + case C128: { + auto complex_data = data(); + TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = + complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)}; + } + break; + } case TUPLE: return InvalidArgument("Should not be called on tuple shapes: %s", ShapeUtil::HumanString(subshape())); @@ -2064,8 +2120,8 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { } string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(shape().element_type(), U8); return string(absl::bit_cast(data().data()), ShapeUtil::ElementsIn(shape())); @@ -2079,7 +2135,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, << ShapeUtil::HumanString(src_piece->subshape()) << "dest_piece has shape: " << ShapeUtil::HumanString(dest_piece->subshape()); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -2090,7 +2146,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, dest_piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { dest_piece->set_buffer(src_piece->buffer()); } else { // If the shape is neither an array nor tuple, then it must be @@ -2166,7 +2222,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, : MutableLiteralBase() { shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); - CHECK(!ShapeUtil::IsTuple(*shape_)); + CHECK(!shape_->IsTuple()); root_piece_ = new Piece(); root_piece_->set_buffer(const_cast(src_buf_ptr)); @@ -2193,14 +2249,14 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); + CHECK(shape.IsTuple()); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); auto child_piece = Piece(); child_piece.set_subshape(&subshape); - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { BuildPieceSubtree(subshape, &child_piece); } @@ -2210,7 +2266,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(shape_->IsArray()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); @@ -2221,7 +2277,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(shape_->IsTuple()); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); @@ -2230,7 +2286,7 @@ BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, for (int i = 0; i < src_buf_ptrs.size(); ++i) { const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); + CHECK(src_shape.IsArray()); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 67e908e7ec4d4346f4e26a99a42aac26928ec0c2..041151fda1280d6ae7b35d5857ca79788d4f7203 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -867,7 +867,7 @@ class BorrowingLiteral : public LiteralBase { template absl::Span LiteralBase::Piece::data() const { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -880,7 +880,7 @@ absl::Span LiteralBase::Piece::data() const { template absl::Span LiteralBase::Piece::data() { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -961,7 +961,7 @@ void MutableLiteralBase::AppendSparseElement( Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); CHECK_EQ(multi_index.size(), rank); int64 last_element = p.sparse_indices()->index_count(); CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); @@ -977,7 +977,7 @@ void LiteralBase::EachCell( if (ShapeUtil::IsZeroElementArray(shape())) { return; } - std::vector indices(ShapeUtil::Rank(shape()), 0); + std::vector indices(shape().rank(), 0); do { per_cell(indices, Get(indices)); } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); @@ -985,8 +985,8 @@ void LiteralBase::EachCell( template inline void MutableLiteralBase::PopulateR1(absl::Span values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -997,8 +997,8 @@ inline void MutableLiteralBase::PopulateR1(absl::Span values) { template void MutableLiteralBase::PopulateR2( std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 2); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1021,10 +1021,10 @@ void MutableLiteralBase::PopulateR2( template void MutableLiteralBase::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + CHECK_EQ(shape().rank(), values.num_dimensions()); for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } @@ -1053,7 +1053,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, absl::Span values, bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); + int rank = shape().rank(); CHECK_EQ(indices.rank(), rank); int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); CHECK_LE(indices.max_indices(), max_elements); @@ -1077,7 +1077,7 @@ template Status MutableLiteralBase::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); + const int64 rank = this_shape.rank(); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); @@ -1129,7 +1129,7 @@ Status MutableLiteralBase::PopulateParallel(const FnType& generator) { template void MutableLiteralBase::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); for (NativeT& element : data()) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 1ac9a48e805daa86f0dc65b54626195c89241020..9b3de75dd4e9d495778af86fb8fc07909ab4ba81 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -90,6 +90,12 @@ bool CompareEqual(complex64 lhs, complex64 rhs, return CompareEqual(lhs.real(), rhs.real(), multi_index) && CompareEqual(lhs.imag(), rhs.imag(), multi_index); } +template <> +bool CompareEqual(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} template Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, @@ -143,6 +149,14 @@ Status MakeErrorStatus(complex64 lhs, complex64 rhs, } return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } +template <> +Status MakeErrorStatus(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); + } + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); +} // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all @@ -172,53 +186,40 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, // Gets the total element count. For tuples, this is not the count of tuple // elements, but the sum of elements of each tuple element. int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); int64 total = 0; for (int64 i = 0; i < tuple_elements; ++i) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { return ShapeUtil::ElementsIn(shape); } else { return 0; } } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +// Returns whether the given value is infinity. template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } -} - -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +bool IsInf(NativeT val) { + return std::isinf(val); } template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); +bool IsInf(half val) { + return std::isinf(static_cast(val)); } -// Returns whether the given value is infinity. +// Returns whether the given value is nan. template -bool IsInf(NativeT val) { - return std::isinf(val); +float IsNan(NativeT value) { + return std::isnan(value); } template <> -bool IsInf(half val) { - return std::isinf(static_cast(val)); +float IsNan(half value) { + return IsNan(static_cast(value)); } // Converts the given floating-point value to a string. @@ -232,6 +233,11 @@ string FpValueToString(complex64 value) { return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } +template <> +string FpValueToString(complex128 value) { + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); +} + // Returns the absolute value of the given floating point value. This function // is used instead of std::abs directly in order to allow type-dependent // implementations for NearComparator. @@ -311,7 +317,7 @@ class NearComparator { // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); - if (!ShapeUtil::IsArray(expected_.shape())) { + if (!expected_.shape().IsArray()) { return InvalidArgument("Expected array shape; got %s.", ShapeUtil::HumanString(expected_.shape())); } @@ -364,21 +370,39 @@ class NearComparator { // the given literal_index and keeps track of various mismatch statistics. template void CompareValues(T expected, T actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + } else if (IsNan(expected) || IsNan(actual)) { + if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) || + (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is + // used for sorting a std::set of the top mismatchs, and a nan value + // here will result in undefined behavior because nan's do not satisfy + // the strict weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = 0; + rel_error = 0; + } + } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) { + // `fewer_infs_ok` gives us the option of comparing as though `actual` + // were float_max/min rather than inf. + T actual_finite = actual > T{0} ? std::numeric_limits::max() + : std::numeric_limits::lowest(); + abs_error = FpAbsoluteValue(actual_finite - expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. @@ -387,12 +411,18 @@ class NearComparator { rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + const bool is_mismatch = is_abs_mismatch && is_rel_mismatch; // Update the error of the relative bucket only if the *absolute* error // bound is exceeded and vice versa. @@ -427,7 +457,7 @@ class NearComparator { mismatches_.data()[linear_index] = true; } - // For complex64 types, we compare real and imaginary parts individually. + // For complex types, we compare real and imaginary parts individually. void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { bool mismatch = false; CompareValues(expected.real(), actual.real(), linear_index); @@ -450,6 +480,29 @@ class NearComparator { mismatches_.data()[linear_index] = mismatch; } + void CompareValues(complex128 expected, complex128 actual, + int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. @@ -463,7 +516,7 @@ class NearComparator { } return; } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + std::vector multi_index(actual_.shape().rank(), 0); CompareLiteralsSlow(0, &multi_index); } @@ -658,6 +711,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case C64: result = Equal(expected, actual, index, 0); break; + case C128: + result = Equal(expected, actual, index, 0); + break; case TUPLE: { for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { result.Update(EqualHelper(LiteralSlice(expected, {i}), @@ -680,12 +736,12 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback, const ShapeIndex& shape_index) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - if (ShapeUtil::IsTuple(expected.shape())) { + if (expected.shape().IsTuple()) { Status return_status; for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { const auto expected_element = LiteralSlice(expected, {i}); @@ -721,26 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { + bool use_detailed_message = detailed_message.value_or( + ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F32: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); + break; + case C128: + return NearComparator::Compare( + expected, actual, error, use_detailed_message, miscompare_callback); break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -761,7 +823,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } - if (ShapeUtil::IsTuple(expected)) { + if (expected.IsTuple()) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( @@ -776,8 +838,8 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else if (ShapeUtil::IsArray(expected)) { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + } else if (expected.IsArray()) { + if (expected.rank() != actual.rank()) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); @@ -831,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback) { VLOG(1) << "Expected literal:"; XLA_VLOG_LINES(1, expected.ToString()); diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc..23fff3fa348f1652eaec344da4c40ccf3ad1079a 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -55,9 +55,10 @@ using MiscompareCallback = // being compared. // // If detailed_message is true, then the error message in the assertion result -// will contain a more detailed breakdown of mismatches. +// will contain a more detailed breakdown of mismatches. By default, we display +// a detailed message only for "large" inputs. Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback); // Calling ToString on a literal with over 100 million elements takes around diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index d8c7141cacb8f60cb4ce56d07ac5827a8dbf9b20..b54a71ae68218ef578535a913f5867d843236e32 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -118,6 +118,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); + auto c128_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); + EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString()); + auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); @@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) { EXPECT_NE(vector, vector_reversed); } +TEST_F(LiteralUtilTest, C128Equality) { + // Test equality with tuples. + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(vector, vector_clone); + + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(vector, vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -623,7 +641,7 @@ template class LiteralUtilTestTemplated : public ::testing::Test {}; using TestedTypes = ::testing::Types; -TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); +TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. @@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateR1C128) { + Literal output(ShapeUtil::MakeShape(C128, {1})); + output.PopulateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { + Literal output(ShapeUtil::MakeShape(C128, {2, 2})); + output.PopulateWithValue({4, 2}); + auto expected = + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); @@ -1237,11 +1270,21 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto s16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto u16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, @@ -1298,9 +1341,19 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - // clang-format on + auto c128 = LiteralUtil::CreateR4WithLayout({{ + {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, + {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, + {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, + }}, layout_r4_dim0major_); // clang-format on Literal conv; + conv = s8.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s8.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + conv = s8.Convert(U32).ConsumeValueOrDie(); EXPECT_EQ(conv, u32); @@ -1352,12 +1405,26 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = f16.Convert(C64).ConsumeValueOrDie(); EXPECT_EQ(conv, c64); + conv = s32.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + + conv = s32.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s32.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + + conv = f16.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(F32).status().code(), + tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(S32).status().code(), + tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1642,7 +1709,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); EXPECT_EQ(literal.Get({}, /*shape_index=*/{0}), 1.0); @@ -1659,7 +1726,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } @@ -1719,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}), + ShapeUtil::MakeShape(C128, {})})); EXPECT_EQ(tuple.Get({}, {0}), 0.0); EXPECT_EQ(tuple.Get({0}, {1}), false); @@ -1727,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); + EXPECT_EQ(tuple.Get({}, {4}), complex128(0.0, 0.0)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1736,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_c128 = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = @@ -1756,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(tuple, to_from_proto(tuple)); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index bb5e5e61000d0aca6ab052ac87d2fbcd96e55f70..26b029c8d0c52e38510f9279def7c4af2904931d 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,7 +62,7 @@ Literal ConvertType(LiteralSlice literal) { ShapeUtil::ForEachSubshape( literal.shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); @@ -106,12 +106,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(0); + case U16: + return LiteralUtil::CreateR0(0); case U32: return LiteralUtil::CreateR0(0); case U64: return LiteralUtil::CreateR0(0); case S8: return LiteralUtil::CreateR0(0); + case S16: + return LiteralUtil::CreateR0(0); case S32: return LiteralUtil::CreateR0(0); case S64: @@ -126,11 +130,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(0); case C64: return LiteralUtil::CreateR0(0); + case C128: + return LiteralUtil::CreateR0(0); case PRED: return LiteralUtil::CreateR0(false); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -164,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(1); case C64: return LiteralUtil::CreateR0(1); + case C128: + return LiteralUtil::CreateR0(1); case PRED: return LiteralUtil::CreateR0(true); case S16: @@ -200,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) { -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; + case C128: + LOG(FATAL) << "C128 element type has no minimum value"; case PRED: return LiteralUtil::CreateR0(false); case S16: @@ -344,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) { new_literal.Set(to_multi_index, literal.Get(from_multi_index)); break; + case C128: + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); @@ -355,7 +366,7 @@ Literal ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::GetFirstScalarLiteral( const LiteralSlice& literal) { - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: @@ -392,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: return LiteralUtil::CreateR0(literal.GetFirstElement()); + + case C128: + return LiteralUtil::CreateR0( + literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 4eab4fa4290c270697c00be20840cf4e85459183..bad65ac32018fafcc7634b989f1b4b0867aa5c0d 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/metric_table_report.h" -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) { const auto metric_greater = [](const Entry& a, const Entry& b) { return a.metric > b.metric; }; - std::sort(entries_.begin(), entries_.end(), metric_greater); + absl::c_sort(entries_, metric_greater); // Create the report AppendLine(); @@ -117,7 +117,7 @@ std::vector MetricTableReport::MakeCategories( auto metric_sum_greater = [](const Category& a, const Category& b) { return a.metric_sum > b.metric_sum; }; - std::sort(categories.begin(), categories.end(), metric_sum_greater); + absl::c_sort(categories, metric_sum_greater); return categories; } @@ -249,7 +249,7 @@ string MetricTableReport::MetricString(double metric) { string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. - while (!sp1.empty() && !isdigit(sp1[0])) { + while (!sp1.empty() && !absl::ascii_isdigit(sp1[0])) { output.push_back(sp1[0]); sp1.remove_prefix(1); } diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 5b568888d14f21c1330556d017eafba6c8dd2228..91e71f5d1d02d135158d0dffc140c21cf8ea5e3a 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { -static const char kWS[] = " \t\r\n"; // whitespace +static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed // from data gleaned from the environment. @@ -104,7 +105,8 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { // Set e to the index just past the end of the flag. size_t e = b; while (e != flag_str.size() && isascii(flag_str[e]) && - (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + (strchr("-_", flag_str[e]) != nullptr || + absl::ascii_isalnum(flag_str[e]))) { e++; } if (e != flag_str.size() && flag_str[e] == '=' && diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 00ad01fc407017624a9183d69e61cb0d382e3f11..1eedddf72c1d393cb1b88e589881e24de02ad802 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -18,16 +18,32 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace primitive_util { +int SignificandWidth(PrimitiveType type) { + switch (type) { + case F32: + return std::numeric_limits::digits; + case F64: + return std::numeric_limits::digits; + case BF16: + return kBFloat16MantissaBits + 1; + case F16: + return 11; + default: + LOG(FATAL) << "Not a floating data type " << type; + } +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } -bool IsComplexType(PrimitiveType type) { return type == C64; } +bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; @@ -67,6 +83,9 @@ int BitWidth(PrimitiveType type) { case C64: return 64; + case C128: + return 128; + case TUPLE: LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; @@ -78,10 +97,27 @@ int BitWidth(PrimitiveType type) { } } +xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + return xla::U32; + case 64: + return xla::U64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: return F32; + case C128: + return F64; default: LOG(FATAL) << "Primitive type is not complex: " << PrimitiveType_Name(complex_type); diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 70603b6fed1be50c427799e6dce7b8bf9631a6f4..295d353003276b4c1731f7d6a378fd1ae0288d3c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -29,6 +29,10 @@ limitations under the License. namespace xla { namespace primitive_util { +// Returns the count of significand (mantissa) bits for float datatypes. +// For non-float datatypes, results in a LOG(FATAL). +int SignificandWidth(PrimitiveType type); + // The number of exponent bits in a BF16 value. const int kBFloat16ExponentBits = 8; @@ -126,6 +130,11 @@ inline PrimitiveType NativeToPrimitiveType() { return C64; } +template <> +inline PrimitiveType NativeToPrimitiveType() { + return C128; +} + bool IsFloatingPointType(PrimitiveType type); bool IsComplexType(PrimitiveType type); @@ -142,6 +151,8 @@ bool IsArrayType(PrimitiveType primitive_type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. PrimitiveType ComplexComponentType(PrimitiveType complex_type); @@ -225,6 +236,11 @@ struct PrimitiveTypeToNative { using type = complex64; }; +template <> +struct PrimitiveTypeToNative { + using type = complex128; +}; + // Returns the lower-case name of the given primitive type. const string& LowercasePrimitiveTypeName(PrimitiveType s); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ddffafa9017a565f01c3214360a958e6840e9148..4afb21d5c8864c2974114af2de08df4106a13a8c 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,8 +3,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") py_library( name = "xla_client", @@ -98,6 +98,11 @@ tf_py_wrap_cc( "local_computation_builder.i", "//tensorflow/python:platform/base.i", ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), deps = [ ":local_computation_builder", ":numpy_bridge", @@ -105,7 +110,5 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service:gpu_plugin", - ]), + ] + xla_python_default_plugins(), ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 657a09f92ad14d959416c768b09c392ff17f96eb..0e898d494e044509a41209891c28d929dff11b9a 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -102,7 +102,7 @@ int GetReplicaCount() { return g_replica_count; } -LocalClient* GetOrCreateLocalClient() { +StatusOr GetOrCreateLocalClient() { string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { @@ -111,7 +111,8 @@ LocalClient* GetOrCreateLocalClient() { LocalClientOptions options; options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); - g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); + TF_ASSIGN_OR_RETURN(g_local_client, + ClientLibrary::GetOrCreateLocalClient(options)); CHECK(g_local_client != nullptr); return g_local_client; } @@ -133,7 +134,7 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { Status TransferToInfeedLocal(const Literal& literal) { VLOG(1) << "Infeeding literal without replica number; shape: " << literal.shape(); - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); } @@ -141,7 +142,7 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number) { VLOG(1) << "Infeeding shape " << literal.shape() << " to replica number: " << replica_number; - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); TF_ASSIGN_OR_RETURN(int device_ordinal, client->ReplicaNumberToDeviceOrdinal(replica_number)); return client->TransferToInfeedLocal(literal, device_ordinal); @@ -151,7 +152,7 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); TF_ASSIGN_OR_RETURN(int device_ordinal, client->ReplicaNumberToDeviceOrdinal(replica_number)); return client->TransferFromOutfeedLocal(shape, device_ordinal); @@ -168,7 +169,7 @@ static StatusOr ToBuffer(LocalClient* client, StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, int replica_number) { - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); TF_ASSIGN_OR_RETURN(int device_ordinal, client->ReplicaNumberToDeviceOrdinal(replica_number)); VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " @@ -198,7 +199,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -333,37 +334,34 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr CompiledLocalComputation::Execute( absl::Span argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); - StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + if (num_replicas() != 1) { + return InvalidArgument( + "Attempted to execute computation with %d replicas using Execute()", + num_replicas()); + } + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + client->backend().computation_placer()->AssignDevices( + 1, /*computation_count=*/1)); StatusOr result_buffer_status; - if (!device_ordinal_status.ok()) { - result_buffer_status = device_ordinal_status.status(); - } else { - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica 0 mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(1, /*computation_count=*/1) - .ConsumeValueOrDie(); + const int device_ordinal = device_assignment(0, 0); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); - result_buffer_status = executable_->Run(argument_buffers, options); - } + result_buffer_status = executable_->Run(argument_buffers, options); if (!result_buffer_status.ok()) { return InternalError( @@ -376,29 +374,30 @@ StatusOr CompiledLocalComputation::Execute( StatusOr CompiledLocalComputation::ExecutePerReplica( absl::Span> argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); - const int num_replicas = GetReplicaCount(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); + const int num_devices = client->device_count(); - if (argument_handles.size() != num_replicas) { + if (argument_handles.size() != num_replicas()) { return InvalidArgument( "Attempted to execute with %d replicas when replica count is %d", - argument_handles.size(), num_replicas); + argument_handles.size(), num_devices); } + if (argument_handles.size() > num_devices) { + return InvalidArgument( + "Attempted to execute with %d replicas when device count is %d", + argument_handles.size(), num_devices); + } + + VLOG(1) << "Executing with " << num_replicas() << " replicas."; - VLOG(1) << "Executing with " << num_replicas << " replicas."; + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + client->backend().computation_placer()->AssignDevices( + num_replicas(), /*computation_count=*/1)); - // Each replica populates a StatusOr result, but only the output value of - // replica zero is returned. - std::vector> results(num_replicas); - auto execute = [this, client, num_replicas, &argument_handles, + std::vector> results(num_replicas()); + auto execute = [this, client, &device_assignment, &argument_handles, &results](int replica) { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); + const int device_ordinal = device_assignment(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -408,12 +407,6 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( argument_buffers.push_back(handle->shaped_buffer()); } - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(num_replicas, /*computation_count=*/1) - .ConsumeValueOrDie(); - ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); options.set_allocator(client->backend().memory_allocator()); @@ -426,23 +419,23 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( results[replica] = std::move(result_buffer_status); }; - if (num_replicas == 1) { + if (num_replicas() == 1) { // Fast-path if there is only one replica — run the computation on the // current thread. execute(0); } else { // TODO(phawkins): don't recreate the threadpool for each execution. tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - num_replicas - 1); + num_replicas() - 1); - for (int replica = 0; replica < num_replicas - 1; ++replica) { + for (int replica = 0; replica < num_replicas() - 1; ++replica) { pool.Schedule([&execute, replica] { execute(replica); }); } - execute(num_replicas - 1); + execute(num_replicas() - 1); } - std::vector wrapped_results(num_replicas); - for (int replica = 0; replica < num_replicas; ++replica) { + std::vector wrapped_results(num_replicas()); + for (int replica = 0; replica < num_replicas(); ++replica) { auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( @@ -549,7 +542,7 @@ StatusOr LocalComputation::Compile( argument_shape_pointers.push_back(&argument_shape); } - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; @@ -698,8 +691,9 @@ LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return xla::CrossReplicaSum(operand.op()); +LocalOp LocalComputationBuilder::CrossReplicaSum( + const LocalOp& operand, absl::Span replica_groups) { + return xla::CrossReplicaSum(operand.op(), replica_groups); } LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, @@ -927,6 +921,22 @@ LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, conjugate_a); } +LocalOp LocalComputationBuilder::Gather( + const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes) { + return xla::Gather(input.op(), start_indices.op(), dimension_numbers, + slice_sizes); +} + +LocalOp LocalComputationBuilder::Scatter( + const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), + update_computation.computation(), dimension_numbers); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -1041,7 +1051,7 @@ StatusOr DestructureLocalShapedBufferTuple( LocalShapedBuffer* local_shaped_buffer) { const Shape tuple_shape = local_shaped_buffer->shape(); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", @@ -1088,7 +1098,7 @@ StatusOr DestructureXrtAllocationTuple( XrtAllocation* allocation, const string& session_target) { const Shape& tuple_shape = allocation->shape(); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5e8341592100bc1eba4d1c17b0c2dd0e0888fdb1..6170567f9ff8f5a062f47d148900fe3676a74542 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -180,6 +180,10 @@ class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); + int num_replicas() const { + return executable_->build_options().num_replicas(); + } + StatusOr Execute( absl::Span argument_handles); @@ -312,7 +316,8 @@ class LocalComputationBuilder { LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); - LocalOp CrossReplicaSum(const LocalOp& operand); + LocalOp CrossReplicaSum(const LocalOp& operand, + absl::Span replica_groups); LocalOp Slice(const LocalOp& operand, absl::Span start_indices, absl::Span limit_indices, @@ -418,6 +423,15 @@ class LocalComputationBuilder { LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, bool lower, bool transpose_a, bool conjugate_a); + LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes); + + LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, + const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index bf5d667c6a12972845735983a74264ea05675971..6a85ed62dea3dbdbb25a990e6d774a0152439673 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -34,6 +34,9 @@ limitations under the License. // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -167,8 +170,41 @@ bool HandleStringAttribute(PyObject* o, return true; // Handled string attribute, ok! } +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; } -} + +} // namespace swig +} // namespace xla %} // Required to use PyArray_* functions. @@ -657,128 +693,27 @@ tensorflow::ImportNumpy(); %typemap(in) const DotDimensionNumbers& (DotDimensionNumbers dimension_numbers) { - int length; - - /* lhs_contracting_dimensions */ - PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( - $input, "lhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { SWIG_fail; } - - length = PySequence_Size(lhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(lhs_contracting_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { SWIG_fail; } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_contracting_dimensions); - - /* rhs_contracting_dimensions */ - PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( - $input, "rhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { SWIG_fail; } - - length = PySequence_Size(rhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(rhs_contracting_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_contracting_dimensions); - - /* lhs_batch_dimensions */ - PyObject* lhs_batch_dimensions = PyObject_GetAttrString( - $input, "lhs_batch_dimensions"); - if (!lhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(lhs_batch_dimensions); - if (length == -1) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); - if (!item) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_batch_dimensions); - - /* rhs_batch_dimensions */ - PyObject* rhs_batch_dimensions = PyObject_GetAttrString( - $input, "rhs_batch_dimensions"); - if (!rhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_batch_dimensions); - if (length == -1) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); - if (!item) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_batch_dimensions); - $1 = &dimension_numbers; } @@ -860,90 +795,108 @@ tensorflow::ImportNumpy(); } dimension_numbers.set_kernel_input_feature_dimension(value); - PyObject* o; - int length; - - o = PyObject_GetAttrString($input, "input_spatial_dimensions"); - if (!o) { + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_input_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); - if (!o) { + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_kernel_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "output_spatial_dimensions"); - if (!o) { + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_output_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; } - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); $1 = &dimension_numbers; } +// Span + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} + + // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -1000,6 +953,12 @@ tensorflow::ImportNumpy(); } Py_DECREF(o); + int64 num_replicas; + if (!GetIntAttr($input, "num_replicas", &num_replicas)) { + SWIG_fail; + } + build_options.set_num_replicas(num_replicas); + $1 = &build_options; } } @@ -1151,6 +1110,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::QR; %unignore xla::swig::LocalComputationBuilder::TriangularSolve; %unignore xla::swig::LocalComputationBuilder::CustomCall; +%unignore xla::swig::LocalComputationBuilder::Gather; +%unignore xla::swig::LocalComputationBuilder::Scatter; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index b0aa024c7474cf8e6934432b2f364be464714999..52c5c621f7294c5da341879d15b77559fe870551 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -54,6 +54,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT64; case C64: return NPY_COMPLEX64; + case C128: + return NPY_COMPLEX128; case TUPLE: return NPY_OBJECT; default: @@ -89,6 +91,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F64; case NPY_COMPLEX64: return C64; + case NPY_COMPLEX128: + return C128; case NPY_OBJECT: return TUPLE; default: @@ -111,6 +115,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT32: case NPY_FLOAT64: case NPY_COMPLEX64: + case NPY_COMPLEX128: case NPY_OBJECT: return true; default: @@ -123,7 +128,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); PyObject* dimensions; - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); for (int i = 0; i < num_elements; ++i) { @@ -132,7 +137,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); } } else { - int rank = ShapeUtil::Rank(shape); + int rank = shape.rank(); dimensions = PyTuple_New(rank); for (int i = 0; i < rank; ++i) { PyTuple_SET_ITEM(dimensions, i, @@ -345,7 +350,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { } PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { @@ -354,7 +359,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } return tuple; } else { - int rank = ShapeUtil::Rank(literal.shape()); + int rank = literal.shape().rank(); std::vector dimensions(rank); // NOLINT - PyArray requires a long* for (int i = 0; i < rank; i++) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); @@ -430,6 +435,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_COMPLEX64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX128: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -470,6 +478,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_COMPLEX64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX128: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..bce6c1acf8a1cc0005ca93e0466c5a0e29d880de --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds @@ -0,0 +1 @@ +_PyInit__pywrap_xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d31cfce7be7b6accf05ef77f3485904099965afc --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds @@ -0,0 +1,6 @@ +xla { + global: + PyInit_*; + local: + *; +}; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 378bbdcb175f10d73da87f5286cf5129477a124c..8964b158292371d662368cfb0b644667985f719e 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -199,6 +199,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { xla_data_pb2.F32: np.dtype('float32'), xla_data_pb2.F64: np.dtype('float64'), xla_data_pb2.C64: np.dtype('complex64'), + xla_data_pb2.C128: np.dtype('complex128'), xla_data_pb2.TUPLE: np.dtype(np.object), } @@ -458,6 +459,7 @@ class CompileOptions(object): self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False + self.num_replicas = get_replica_count() def transfer_to_infeed(value, replica_number=None): @@ -963,16 +965,30 @@ class ComputationBuilder(object): dimensions = tuple(range(ndim)) return self._client.Reshape(operand, dimensions, new_sizes) - def CrossReplicaSum(self, operand): + def CrossReplicaSum(self, operand, replica_groups=None): """CrossReplicaSum op. Args: operand: the operand to sum across replica instances. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the cross-replica sum is performed. If not supplied or None + (the default), all replicas belong to the same group. Returns: - A LocalOp that has the sum of the value among all replicas. + A LocalOp that represents on each replica the sum of its group's values. """ - return self._client.CrossReplicaSum(operand) + + def make_proto(replica_group): + replica_group_proto = xla_data_pb2.ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + if replica_groups is None: + replica_groups = [] # special value for XLA API + else: + replica_groups = [make_proto(group) for group in replica_groups] + return self._client.CrossReplicaSum(operand, replica_groups) def Collapse(self, operand, dimensions): """Collapse op.""" @@ -1477,6 +1493,18 @@ class ComputationBuilder(object): return self._client.TriangularSolve( a, b, left_side, lower, transpose_a, conjugate_a) + def Gather(self, a, start_indices, dimension_numbers, slice_sizes): + """Enqueues a Gather operation onto the computation.""" + return self._client.Gather(a, start_indices, dimension_numbers, + slice_sizes) + + def Scatter(self, a, scatter_indices, updates, update_computation, + dimension_numbers): + """Enqueues a Scatter operation onto the computation.""" + return self._client.Scatter( + a, scatter_indices, updates, update_computation.computation, + dimension_numbers,) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 002a20e60a9fbe117af991731a555e60eef9397a..54c76241b9929fb39a6d63648f8ff35d78534b28 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -565,6 +565,18 @@ class SingleOpTest(LocalComputationTest): c.CrossReplicaSum(c.Constant(lhs)) self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs), [[0]]) + self._ExecuteAndCompareExact(c, expected=lhs) + def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) @@ -1129,6 +1141,21 @@ class SingleOpTest(LocalComputationTest): self.assertFalse(c.IsConstant(non_const_expr)) # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.xla_data_pb2.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) + g = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1186,6 +1213,14 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) return c.Build() + def _CreateBinaryAddS32Computation(self): + """Computation (s32, s32) -> s32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayS32(0)), + c.ParameterFromNumpy(NumpyArrayS32(0))) + return c.Build() + def _CreateBinaryAddF32Computation(self): """Computation (f32, f32) -> f32 that adds its two parameters.""" c = self._NewComputation("add_param0_by_param1") @@ -1568,6 +1603,23 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.join() self.assertEqual(want, got) + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.xla_data_pb2.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), + self._CreateBinaryAddS32Computation(), dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=expected) + class ErrorTest(LocalComputationTest): diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 92f28a9f8aaa3106b9a58ae1ee93ef8841ab58ef..08b78ee244844f41d551d7e249cec0cbf157d639 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -550,9 +551,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; Literal result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); + CHECK_EQ(result_literal.shape().rank(), 4); auto result = absl::make_unique>(result_literal.shape().dimensions(0), result_literal.shape().dimensions(1), @@ -605,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); - const std::set dim_set(dims.begin(), dims.end()); + const absl::flat_hash_set dim_set(dims.begin(), dims.end()); CHECK_EQ(dim_set.size(), 3); - for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { - for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); + for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1()); + ++a0) { + for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2()); ++a1) { - for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); + for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3()); ++a2) { - for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); + for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) { float accumulator = init; - for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); - ++i0) { - for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); - ++i1) { + for (int64 i0 = 0; + i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) { + for (int64 i1 = 0; + i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) { for (int64 i2 = 0; - i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { + i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; - i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { + i3 == 0 || (dim_set.contains(3) && i3 < array.n4()); + ++i3) { // Handle zero-sized arrays. if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && array.n4() > 0) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index d8123a6de28ca532819ece4a75cd0b725f8c1bbd..22b4218fbd5e9bc59a0de22735eb51db46670f09 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,6 +47,14 @@ namespace xla { }); } +::grpc::Status GRPCService::GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->GetDeviceHandles(arg, result); + }); +} + ::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, const CompileRequest* arg, CompileResponse* result) { @@ -61,6 +69,14 @@ namespace xla { [this, arg, result]() { return service_->Execute(arg, result); }); } +::grpc::Status GRPCService::ExecuteGraphParallel( + ::grpc::ServerContext* /*context*/, const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->ExecuteGraphParallel(arg, result); + }); +} + ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 3e586b288a56a22573d0c3b9ae7b2f25fdbf851a..b546704f73e34941cbf7bc2fe08062aa438039f7 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,6 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; + ::grpc::Status GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; + ::grpc::Status Compile(::grpc::ServerContext* context, const CompileRequest* arg, CompileResponse* result) override; @@ -46,6 +50,9 @@ class GRPCService : public grpc::XlaService::Service { ::grpc::Status Execute(::grpc::ServerContext* context, const ExecuteRequest* arg, ExecuteResponse* result) override; + ::grpc::Status ExecuteGraphParallel(::grpc::ServerContext* context, + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d8736c819687482a9dead57bdeacff8e75dce105..34af6b35972e8e484eee3d5419da17095556aebc 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,14 @@ # Description: # XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -12,15 +20,6 @@ package_group( ], ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library_py", -) - xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], @@ -224,19 +223,23 @@ cc_library( "hlo_evaluator_typed_visitor.h", "hlo_evaluator_typed_visitor_bfloat16.cc", "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex128.cc", "hlo_evaluator_typed_visitor_complex64.cc", "hlo_evaluator_typed_visitor_double.cc", "hlo_evaluator_typed_visitor_float.cc", "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int16.cc", "hlo_evaluator_typed_visitor_int32.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", "hlo_evaluator_typed_visitor_uint8.cc", ], hdrs = ["hlo_evaluator.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_casting_utils", ":hlo_query", @@ -257,6 +260,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -268,6 +272,7 @@ tf_cc_test( srcs = ["hlo_evaluator_test.cc"], deps = [ ":hlo", + ":hlo_element_type_converter", ":hlo_evaluator", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -280,7 +285,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -516,6 +520,7 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -678,6 +683,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -696,6 +702,7 @@ cc_library( ":compiler", ":computation_layout", ":device_memory_allocator", + ":dynamic_dimension_inference", ":executable", ":execution_tracker", ":hlo", @@ -1003,6 +1010,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1054,7 +1062,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1092,7 +1099,6 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", - ":hlo_memory_scheduler", ":hlo_proto", ":logical_buffer", ":tuple_points_to_analysis", @@ -1137,6 +1143,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1231,7 +1238,6 @@ cc_library( deps = [ ":hlo", ":hlo_proto", - "//tensorflow/compiler/xla:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -1499,7 +1505,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -1580,6 +1585,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1720,9 +1727,9 @@ cc_library( ) tf_cc_test( - name = "convolution_feature_group_converter_test", + name = "convolution_group_converter_test", size = "small", - srcs = ["convolution_feature_group_converter_test.cc"], + srcs = ["convolution_group_converter_test.cc"], deps = [ ":convolution_group_converter", ":hlo", @@ -1866,8 +1873,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -1931,6 +1939,46 @@ cc_library( ], ) +cc_library( + name = "dynamic_padder", + srcs = ["dynamic_padder.cc"], + hdrs = ["dynamic_padder.h"], + deps = [ + ":dynamic_dimension_inference", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "dynamic_padder_test", + srcs = ["dynamic_padder_test.cc"], + deps = [ + ":dynamic_padder", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "dynamic_dimension_inference_test", srcs = ["dynamic_dimension_inference_test.cc"], @@ -2017,7 +2065,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -2058,6 +2105,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -2116,6 +2164,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2288,6 +2338,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2548,6 +2599,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2592,6 +2644,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_module_config", ":hlo_parser", ":hlo_verifier", ":layout_assignment", @@ -2599,6 +2652,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2969,13 +3023,11 @@ cc_library( srcs = ["hlo_get_dimension_size_rewriter.cc"], hdrs = ["hlo_get_dimension_size_rewriter.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_pass", ":shape_inference", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", @@ -3186,6 +3238,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -3346,7 +3399,6 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3403,10 +3455,39 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "sort_simplifier", + srcs = ["sort_simplifier.cc"], + hdrs = ["sort_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "sort_simplifier_test", + srcs = ["sort_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":sort_simplifier", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -3505,7 +3586,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3574,14 +3654,16 @@ cc_library( tf_cc_test( name = "indexed_array_analysis_test", srcs = ["indexed_array_analysis_test.cc"], + extra_copts = ["-Wno-string-plus-int"], deps = [ ":hlo_matchers", + ":hlo_parser", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -3675,6 +3757,7 @@ cc_library( ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -3686,6 +3769,38 @@ cc_library( ], ) +cc_library( + name = "dynamic_index_splitter", + srcs = ["dynamic_index_splitter.cc"], + hdrs = ["dynamic_index_splitter.h"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "dynamic_index_splitter_test", + srcs = ["dynamic_index_splitter_test.cc"], + deps = [ + ":dynamic_index_splitter", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "ar_crs_combiner_test", srcs = ["ar_crs_combiner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 9e453203ce17cceb606cac06d0ebfaccbf912126..da15ff7d7a2bee8f142bacc996f7fcd063598f77 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,6 +26,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -34,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -50,6 +53,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -120,23 +124,37 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape/copy produces a result which is bit-wise -// identical to its operand and thus may be replaced with a bitcast. -// -// This function is conservative -- even if this function returns false, the -// reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeOrCopyIsBitcast( - const HloInstruction* instr, - const AlgebraicSimplifierOptions::ValidBitcastCallback& - valid_bitcast_callback) { +// Recursive helper for method below. +HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper( + HloInstruction* instr, HloInstruction* operand, + const AlgebraicSimplifierOptions& options) { + // Can't replace chain of copies and reshapes with bitcasts if the compiler + // used a memory layout which isn't compatible. + if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) { + return operand; + } + + // If the operand is a copy or reshape try to see if the operand's operand + // would produce a bitcast with initial instruction. + if (HloOpcode::kReshape == operand->opcode() || + HloOpcode::kCopy == operand->opcode()) { + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, operand->mutable_operand(0), options); + } + return nullptr; +} + +// Returns an operand of a chain of reshapes and copies that is bit-wise +// identical to first reshape or copy in the chain. +HloInstruction* BitcastingOperandOfReshapeOrCopyChain( + HloInstruction* instr, const AlgebraicSimplifierOptions& options) { + if (!options.is_layout_sensitive()) { + return nullptr; + } CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); - - const HloInstruction* operand = instr->operand(0); - // Can't insert bitcasts if the compiler used a memory layout which isn't - // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && - valid_bitcast_callback(operand->shape(), instr->shape()); + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, instr->mutable_operand(0), options); } bool IsUnstridedSlice(const HloInstruction* hlo) { @@ -203,6 +221,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandlePower(HloInstruction* power) override; + Status HandleRemainder(HloInstruction* remainder) override; + Status HandleReshape(HloInstruction* reshape) override; Status HandleReduce(HloInstruction* reduce) override; @@ -251,7 +271,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { - if (ShapeUtil::Rank(hlo->shape()) == 1) { + if (hlo->shape().rank() == 1) { return hlo; } return computation_->AddInstruction(HloInstruction::CreateReshape( @@ -271,8 +291,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { shape, hlo, zero, {dim}, AddReduce_computation)); } - // Convenience method for replacing an instruction with a bitcast. - void ReplaceWithBitcast(HloInstruction* instruction); + // Convenience method for replacing an instruction with a bitcast. If operand + // is not null, then the bitcast will use the specified operand instead of the + // operand of the instruction. + void ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand = nullptr); // Replace old instruction with new instruction if old and new instructions // have the same shape. Updates uses and root instruction. Returns whether a @@ -401,17 +424,19 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, } } -void AlgebraicSimplifierVisitor::ReplaceWithBitcast( - HloInstruction* instruction) { +void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand) { CHECK_EQ(1, instruction->operand_count()); + if (operand == nullptr) { + operand = instruction->mutable_operand(0); + } CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), - ShapeUtil::ElementsIn(instruction->operand(0)->shape())); + ShapeUtil::ElementsIn(operand->shape())); CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), - ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); + ShapeUtil::ByteSizeOf(operand->shape())); - auto bitcast = computation_->AddInstruction( - HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, - instruction->mutable_operand(0))); + auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kBitcast, operand)); TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } @@ -572,9 +597,9 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(copy); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { + ReplaceWithBitcast(copy, bitcast_operand); } return Status::OK(); @@ -687,7 +712,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } PaddingConfig padding_config; - for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { + for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); @@ -715,7 +740,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( static HloInstruction* BuildTupleConstant(HloComputation* computation, const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { @@ -732,7 +757,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // Tuple constants aren't directly supported by any backend. Expand them into // explicit Tuple instructions. - if (ShapeUtil::IsTuple(constant->shape())) { + if (constant->shape().IsTuple()) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } @@ -754,7 +779,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } // If a literal is an increasing sequence from zero, replace it with an iota. - if (ShapeUtil::Rank(constant->shape()) == 1 && + if (constant->shape().rank() == 1 && ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsR1Iota()) { return ReplaceWithNewInstruction( @@ -791,6 +816,79 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { return T{1.0} / constant.literal().Get(indices); }); } + +template +std::unique_ptr TryDivideToShift(HloInstruction* divide, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(divide->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(divide->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. + HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + + auto* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend, + shift_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, + dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + HloInstruction* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + return HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount); + } + } + + return nullptr; +} } // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { @@ -803,6 +901,60 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / B => A >> log2(B) if B is a power of 2. + switch (divide->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + default: + break; + } + // exp(A)/exp(B) => exp(A-B) if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) .WithShape(m::Shape(&shape)))) { @@ -870,6 +1022,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { case C64: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; + case C128: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; default: return Status::OK(); } @@ -930,9 +1085,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return -1; }; - const int64 dot_rank = ShapeUtil::Rank(dot->shape()); - const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); - const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const int64 dot_rank = dot->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); + const int64 lhs_rank = lhs->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); if (dnums.rhs_contracting_dimensions_size() > 1) { return false; @@ -1036,7 +1191,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // ) if (lhs_rank == 1 || (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { - if (ShapeUtil::Rank(rhs->shape()) == 1) { + if (rhs->shape().rank() == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), rhs), 0)))); @@ -1373,6 +1528,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + HloDynamicSliceInstruction* dynamic_slice = + lhs_is_dynamic_slice ? Cast(lhs) + : Cast(rhs); // ctA: HloInstruction* left_operand = @@ -1390,8 +1548,6 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. - HloInstruction* original_start_indices = - lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); // Position of start: int index_of_non_zero_start = lhs_is_dynamic_slice ? 1 - lhs_contracting_dimension @@ -1400,23 +1556,19 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( int index_of_zero_start = 1 - index_of_non_zero_start; // Slice out start and 0 components and reorder if necessary. - auto indices_type = original_start_indices->shape().element_type(); + auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); HloInstruction* non_zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_non_zero_start}, - {index_of_non_zero_start + 1}, {1})); + dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_zero_start}, - {index_of_zero_start + 1}, {1})); - HloInstruction* new_start_indices = - lhs_is_dynamic_slice - ? computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {non_zero_start, zero_start}, 0)) - : computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {zero_start, non_zero_start}, 0)); + dynamic_slice->mutable_operand(1 + index_of_zero_start); + std::vector new_start_indices; + if (lhs_is_dynamic_slice) { + new_start_indices = {non_zero_start, zero_start}; + } else { + new_start_indices = {zero_start, non_zero_start}; + } // Build DynamicSlice(ctA x ctB). const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; @@ -1449,8 +1601,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot->shape().element_type() != BF16) { return Status::OK(); } - if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || + dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && !options_.is_layout_sensitive()) { TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); @@ -1686,7 +1838,7 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, case HloOpcode::kTranspose: return true; case HloOpcode::kSort: - return (!ShapeUtil::IsTuple(instruction->shape())); + return (!instruction->shape().IsTuple()); default: return false; } @@ -1732,8 +1884,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A degenerate broadcast that has the same input and output rank can be // converted into a transpose. - if (ShapeUtil::Rank(broadcast->shape()) == - ShapeUtil::Rank(operand->shape()) && + if (broadcast->shape().rank() == operand->shape().rank() && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " @@ -1888,7 +2039,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (HasInteriorPadding(pad->padding_config())) { PaddingConfig padding_config = pad->padding_config(); bool cleared_interior_padding = false; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { if (padding_config.dimensions(i).interior_padding() > 0 && pad->operand(0)->shape().dimensions(i) == 1) { cleared_interior_padding = true; @@ -2139,6 +2290,137 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return changed; } +namespace { +template +std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(remainder->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(remainder->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. + HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + auto* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary( + remainder->shape(), HloOpcode::kSelect, dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + HloInstruction* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, + a, mask_amount); + } + } + return nullptr; +} +} // namespace + +Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { + HloInstruction *a, *b; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + // A % B => A & (B - 1) if B is a power of 2. + switch (remainder->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + default: + break; + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto operand = reshape->mutable_operand(0); @@ -2195,12 +2477,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(reshape); - return Status::OK(); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) { + ReplaceWithBitcast(reshape, bitcast_operand); } - return Status::OK(); } @@ -2210,8 +2490,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { auto dim_is_one = [&](int64 i) -> bool { return reverse->shape().dimensions(i) == 1; }; - if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), - dim_is_one)) { + if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { return ReplaceInstruction(reverse, reverse->mutable_operand(0)); } return Status::OK(); @@ -2276,7 +2555,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. - if (ShapeUtil::Rank(slice->shape()) != 1) { + if (slice->shape().rank() != 1) { VLOG(10) << "Not folding, slice is not rank 1"; return false; } @@ -2326,7 +2605,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } HloInstruction* new_slice_operand = reshape->mutable_operand(0); - int64 slice_rank = ShapeUtil::Rank(slice->shape()); + int64 slice_rank = slice->shape().rank(); std::vector sliced_dims; for (int64 i = 0; i < slice_rank; ++i) { if (slice->slice_starts(i) != 0 || @@ -2338,7 +2617,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && slice->slice_starts(0) == 0) { const Shape& new_slice_shape = new_slice_operand->shape(); - const int64 rank = ShapeUtil::Rank(new_slice_shape); + const int64 rank = new_slice_shape.rank(); std::vector new_slice_starts(rank, 0); std::vector new_slice_stides(rank, 1); std::vector new_slice_limits(new_slice_shape.dimensions().begin(), @@ -2438,7 +2717,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Most of those optimizations can be done for multi-output // reduces. - if (ShapeUtil::IsTuple(reduce->shape())) { + if (reduce->shape().IsTuple()) { return Status::OK(); } @@ -2456,8 +2735,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // A Transpose feeding a reduce can simply permute the reduction dimensions // field if the output of the reduce is a vector or scalar. Higher ranked // result may require a transpose of the output. - if (ShapeUtil::Rank(reduce->shape()) <= 1 && - arg->opcode() == HloOpcode::kTranspose) { + if (reduce->shape().rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { @@ -2487,9 +2765,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // Create a new reduce with the combined reduction dimensions of both // reduces. std::vector arg_dims = arg->dimensions(); - std::sort(arg_dims.begin(), arg_dims.end()); + absl::c_sort(arg_dims); std::vector reduce_dims = reduce->dimensions(); - std::sort(reduce_dims.begin(), reduce_dims.end()); + absl::c_sort(reduce_dims); // Transform reduce_dims to the same rank as the operand of the operand. for (int64 arg_dim : arg_dims) { for (int64& dim : reduce_dims) { @@ -2516,8 +2794,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); - std::vector arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); - std::vector arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); + std::vector arg_dim_in_output(arg->shape().rank(), true); + std::vector arg_dim_unmodified(arg->shape().rank(), false); for (auto dim : dimensions) { arg_dim_in_output[dim] = false; } @@ -2535,15 +2813,15 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } if (can_move_reshape_into_reduce) { changed_ = true; - std::unordered_set dimensions_not_to_reduce; + absl::flat_hash_set dimensions_not_to_reduce; for (auto dim_pair : unmodified_dims) { if (arg_dim_in_output[dim_pair.second]) { dimensions_not_to_reduce.insert(dim_pair.first); } } std::vector new_reduce_dimensions; - for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { - if (dimensions_not_to_reduce.count(i) == 0) { + for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { + if (!dimensions_not_to_reduce.contains(i)) { new_reduce_dimensions.push_back(i); } } @@ -2597,51 +2875,53 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } - // A reduce window can be expressed as a reduce and a reshape if all - // dimensions either have a window size of one or the entire dimension. If - // there is no stride, dilation, or padding, this is as easy as checking the - // size of the output shape and window dimension. - // - // The reshape is a bitcast since it adds one-sized dimensions. Often these - // ones are immediately removed as well with another reshape. The - // implementation of reduce tends to be slightly more efficient at reducing - // entire dimensions compared to reduce window. - auto effective_reduce_dims = [&] { - if (window_util::HasStride(window) || window_util::HasDilation(window) || - window_util::HasPadding(window)) { - return absl::InlinedVector{}; - } - absl::InlinedVector reduce_dims; - for (int64 i = 0; i < window.dimensions_size(); ++i) { - if (window.dimensions(i).size() == 1) { - continue; - } else if (reduce_window->shape().dimensions(i) == 1) { - reduce_dims.push_back(i); - } else { + if (options_.enable_window_reduce_to_reduce_replacement()) { + // A reduce window can be expressed as a reduce and a reshape if all + // dimensions either have a window size of one or the entire dimension. If + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. + // + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. + auto effective_reduce_dims = [&] { + if (window_util::HasStride(window) || window_util::HasDilation(window) || + window_util::HasPadding(window)) { return absl::InlinedVector{}; } - } - return reduce_dims; - }(); + absl::InlinedVector reduce_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).size() == 1) { + continue; + } else if (reduce_window->shape().dimensions(i) == 1) { + reduce_dims.push_back(i); + } else { + return absl::InlinedVector{}; + } + } + return reduce_dims; + }(); - // If a reduce window can be expressed as a reduce, do so and reshape the - // output. - if (!effective_reduce_dims.empty()) { - Shape reduce_shape = ShapeUtil::FilterDimensions( - [&](int64 dim) { - return !absl::c_linear_search(effective_reduce_dims, dim); - }, - reduce_window->shape()); - HloInstruction* reduce = - computation_->AddInstruction(HloInstruction::CreateReduce( - /*shape=*/reduce_shape, - /*operand=*/operand, - /*init_value=*/reduce_window->mutable_operand(1), - /*dimensions_to_reduce=*/effective_reduce_dims, - /*reduce_computation=*/function)); - return ReplaceWithNewInstruction( - reduce_window, - HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. + if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(effective_reduce_dims, dim); + }, + reduce_window->shape()); + HloInstruction* reduce = + computation_->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/operand, + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/function)); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + } } // This optimization folds a pad op into reduce_window. @@ -2779,7 +3059,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // Carry out the folding of the pad into reduce_window. VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; - const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + const int64 rank = reduce_window->shape().rank(); TF_RET_CHECK(pad_config.dimensions_size() == rank); TF_RET_CHECK(window.dimensions_size() == rank); for (int64 i = 0; i < rank; ++i) { @@ -2828,6 +3108,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } + if (!options_.enable_permutation_sort_replacement()) { return Status::OK(); } @@ -2862,7 +3143,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { // - Use this as the indices parameter of scatter, and set updates // of the scatter to be a reshaped 'values' parameter of sort (adding // 'rank' many 1 dimensions at the end). - int64 rank = ShapeUtil::Rank(operand->shape()); + int64 rank = operand->shape().rank(); Shape extended_shape = operand->shape(); extended_shape.add_dimensions(1); extended_shape.mutable_layout()->add_minor_to_major(rank); @@ -3221,15 +3502,6 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( convolution_shape.element_type(), {conv_width, output_channels}); - // We cannot insert bitcasts if the layouts will not be compatible. - // TODO(b/33178038): Consider inserting a transpose if a bitcast would be - // invalid. - if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || - !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || - !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { - return false; - } - auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); DotDimensionNumbers dot_dimension_numbers; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d2775b9fafa7e4c625f5d181114e80e7369f9c78..ff3f638b22e290f6f6237a5a72a257aa23ecd78b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -25,21 +25,25 @@ namespace xla { class AlgebraicSimplifierOptions { public: - // Given shapes 'from_shape' and 'to_shape', determines if it is valid to - // bitcast from 'from_shape' to 'to_shape' after considering platform - // dependent effects on layout like alignment restrictions. Precondition: the - // two shapes have layouts, the same number of elements and - // ShapeUtil::ReshapeIsBitcast returns true. - using ValidBitcastCallback = + AlgebraicSimplifierOptions() {} + // Platform dependent callback to determine if a reshape `from_shape` to + // `to_shape` is a bitcast. + using ReshapeIsBitcastCallback = std::function; - explicit AlgebraicSimplifierOptions( - ValidBitcastCallback valid_bitcast_callback) - : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} - // If valid_bitcast_callback returns true, then the pass will replace reshapes - // and transposes with bitcasts. - const ValidBitcastCallback& valid_bitcast_callback() const { - return valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback) + : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {} + + // Use the platform specific callback if set. It is not sensible to return + // true here if the options are not layout sensitive. + bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { + if (!is_layout_sensitive_) { + return false; + } + if (!reshape_is_bitcast_callback_) { + return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); + } + return reshape_is_bitcast_callback_(from_shape, to_shape); } // If is_layout_sensitive is true, then the simplifier preserves layout during @@ -47,12 +51,14 @@ class AlgebraicSimplifierOptions { void set_is_layout_sensitive(bool is_layout_sensitive) { is_layout_sensitive_ = is_layout_sensitive; } + bool is_layout_sensitive() const { return is_layout_sensitive_; } // Enable dot simplification on platforms where it is profitable. void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { enable_dot_strength_reduction_ = enable_dot_strength_reduction; } + bool enable_dot_strength_reduction() const { return enable_dot_strength_reduction_; } @@ -71,16 +77,30 @@ class AlgebraicSimplifierOptions { bool enable_permutation_sort_replacement) { enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; } + bool enable_permutation_sort_replacement() const { return enable_permutation_sort_replacement_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction + // can be optimized by replacement with simpler operations. + void set_enable_window_reduce_to_reduce_replacement( + bool enable_window_reduce_to_reduce_replacement) { + enable_window_reduce_to_reduce_replacement_ = + enable_window_reduce_to_reduce_replacement; + } + + bool enable_window_reduce_to_reduce_replacement() const { + return enable_window_reduce_to_reduce_replacement_; + } + private: - ValidBitcastCallback valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback_; bool is_layout_sensitive_{false}; bool enable_dot_strength_reduction_{true}; bool enable_conv_simplification_{true}; bool enable_permutation_sort_replacement_{false}; + bool enable_window_reduce_to_reduce_replacement_{true}; }; // A pass which performs algebraic simplifications. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a9d617cbf6dcd02283d5d66655c0fa6ddf6dc27f..3602ab82b248bb3d7cd8203ed7664e3c460374d2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -46,17 +46,9 @@ namespace { using ::testing::ElementsAre; namespace m = match; -AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { - return [](const Shape&, const Shape&) { return true; }; -} - -AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { - return [](const Shape&, const Shape&) { return false; }; -} - class AlgebraicSimplifierTest : public HloTestBase { protected: - AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; + AlgebraicSimplifierOptions default_options_; }; // Test that A + 0 is simplified to A @@ -202,6 +194,86 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { m::Broadcast(m::ConstantScalar(0.125))))); } +TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT d = u32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ShiftRightLogical( + m::Parameter(0), m::Broadcast(m::ConstantScalar(3))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT d = s32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_shift = + m::ShiftRightLogical(match_abs, m::Broadcast(m::ConstantScalar(3))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_shift), match_shift))); +} + +TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT r = u32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), + m::Broadcast(m::ConstantScalar(7))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT r = s32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_and = + m::AndAnyOrder(match_abs, m::Broadcast(m::ConstantScalar(7))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_and), match_and))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -1464,23 +1536,77 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { +TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) { auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); - *param->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({0, 1, 2, 3}); + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); - *copy->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({1, 2, 0, 3}); + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1}), copy)); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), + HloOpcode::kCopy, reshape)); + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0)))))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy)); + + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Copy(m::Parameter(0))))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + "param")); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}), + HloOpcode::kCopy, param)); auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); @@ -1488,10 +1614,10 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options2(bitcasting_callback()); + AlgebraicSimplifierOptions options2; options2.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier2(options2); - ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); + EXPECT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Bitcast(m::Parameter(0)))); @@ -1744,7 +1870,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); @@ -1774,7 +1900,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1804,7 +1930,8 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); @@ -1835,8 +1962,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that reshape(transpose(rng)) is replace by a single rng of the @@ -1887,7 +2013,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { m::Op().Is(dimensions_wrong_reshape), m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); simplifier.Run(m.get()).ValueOrDie(); @@ -1917,8 +2043,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1942,8 +2067,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1968,7 +2092,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1998,7 +2122,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -2055,7 +2179,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -2103,9 +2227,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloPassFix simplifier(default_options_); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2651,7 +2774,7 @@ TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_enable_permutation_sort_replacement(true); AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2680,7 +2803,7 @@ TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_enable_permutation_sort_replacement(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); @@ -2703,7 +2826,7 @@ TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_enable_permutation_sort_replacement(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); @@ -2945,7 +3068,7 @@ class ConvInputPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvInputPaddingTestCases, ConvInputPaddingTest, ::testing::ValuesIn(std::vector{ // Merge this edge padding into the conv. @@ -3053,7 +3176,7 @@ class ConvFilterPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvFilterPaddingTestCases, ConvFilterPaddingTest, ::testing::ValuesIn(std::vector{ // Can only merge interior padding on the filter's spatial dimensions; @@ -3292,7 +3415,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options; simplifier_options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { @@ -3498,7 +3621,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3584,7 +3707,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3706,12 +3829,16 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } builder.AddInstruction(HloInstruction::CreateDynamicSlice( shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + params, /*slice_sizes=*/{10, 100, 1000})); auto computation = m->AddEntryComputation(builder.Build()); @@ -3730,28 +3857,35 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( slice_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, slice_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, slice_shape, "to_update")), + slice, update_indices)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. @@ -3858,7 +3992,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3879,7 +4013,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3900,7 +4034,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -3919,7 +4053,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3941,7 +4075,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3963,7 +4097,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -3985,7 +4119,7 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -4005,7 +4139,7 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); @@ -4142,7 +4276,7 @@ PadReduceWindowEffectiveBroadcastCases() { return *cases; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( PadReduceWindowEffectiveBroadcastInstantiation, PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); @@ -4193,7 +4327,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(F32, BF16))); @@ -4250,7 +4384,7 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), @@ -4412,9 +4546,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0({}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dslice_shape, operand, update, start_indices)); + dslice_shape, operand, update, + std::initializer_list({start_indices}))); const HloComputation* const computation = m->AddEntryComputation(builder.Build()); @@ -4423,9 +4558,9 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { EXPECT_THAT(computation->root_instruction(), operand); } -INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, - DotOfConcatSimplificationTest, - ::testing::ValuesIn(kDotOfConcatTestSpecs)); +INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); struct DotOfGatherTestSpec { int64 m; @@ -4467,14 +4602,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_row = (spec.lcd == 0) ? 0 : spec.s; int32 start_col = (spec.lcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_row))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, lhs, start_indices, slice_sizes)); int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; @@ -4507,7 +4645,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } @@ -4545,14 +4683,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_row = (spec.rcd == 0) ? 0 : spec.s; int32 start_col = (spec.rcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_row))), builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, rhs, start_indices, slice_sizes)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(spec.lcd); @@ -4577,7 +4718,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { } else { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), - m::Concatenate()))); + m::Constant(), m::Constant()))); } } @@ -4625,7 +4766,7 @@ std::vector DotOfGatherPositiveNegativeTests() { return all; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ef5e211646e7b0b66b8e6c09948be58063422943..6cb0e985e57016e5a22fba50c3e3ad6970f1b178 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -142,13 +142,13 @@ StatusOr> AllocationTracker::DeconstructTuple( // We only need to care about replica id 0 here, since the GlobalDataHandle is // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; - if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { + if (!shaped_buffer->on_host_shape().IsTuple()) { return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be // as well. - TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); + TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple()); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("Deconstructing nested tuples is not implemented."); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 47d2c7e35705698d49950c2fa042af1c6327d521..f8dff6a700cc9d5843053e3d451a7b005539ca26 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,19 +37,34 @@ namespace { namespace m = match; -// Returns true iff the argument instruction is an AllReduce, followed by a -// certain sequence of instructions and then a CRS. It must be possible to move -// the AR past each instruction in the sequence. -bool MatchesArCrsPattern(HloInstruction* instruction) { +// Checks if the argument instruction is an AllReduce, followed by a certain +// sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. Returns the CRS, which is the +// last instruction in the sequence. +absl::optional MatchesArCrsPattern( + HloInstruction* instruction) { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { if (instruction->user_count() != 1) { return false; } - auto opcode = instruction->opcode(); - return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || - opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || - opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || - opcode == HloOpcode::kMultiply; + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + return true; + case HloOpcode::kConvert: + // Can be moved across if both input and output is either float or + // integer (e.g. S32<->U32 or F32<->BF16) + return ShapeUtil::ElementIsFloating(instruction->shape()) == + ShapeUtil::ElementIsFloating(instruction->operand(0)->shape()); + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + // Only supported for floating point operands. + return ShapeUtil::ElementIsFloating(instruction->shape()); + default: + return false; + } }; auto computation_is_addition = [](HloComputation* c) { @@ -59,17 +75,22 @@ bool MatchesArCrsPattern(HloInstruction* instruction) { if (!instruction->IsCrossModuleAllReduce() || !computation_is_addition(instruction->called_computations()[0]) || instruction->user_count() != 1) { - return false; + return absl::nullopt; } auto next = instruction->users()[0]; while (!next->IsCrossReplicaAllReduce()) { if (can_ar_move_past_instruction(next)) { next = next->users()[0]; } else { - return false; + return absl::nullopt; } } - return computation_is_addition(next->called_computations()[0]); + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + return absl::optional(next); + } else { + return absl::nullopt; + } } } // namespace @@ -85,7 +106,7 @@ absl::optional ArCrsCombiner::WhileFromBodyParameter( return caller_instruction; } } - return absl::optional(); + return absl::nullopt; } std::vector ArCrsCombiner::GetAllTuples( @@ -176,6 +197,15 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + if (i1->IsCrossModuleAllReduce()) { + return i1->Identical(*i2, + /*eq_operands=*/std::equal_to(), + eq_computations, + /*layout_sensitive=*/false); + } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { auto operand1 = operands1[i]; @@ -201,9 +231,6 @@ bool ArCrsCombiner::InstructionsComputeSameValue( // InstructionsComputeSameValue earlier. auto eq_instructions = [](const HloInstruction* i1, const HloInstruction* i2) -> bool { return true; }; - auto eq_computations = [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }; return i1->Identical(*i2, eq_instructions, eq_computations, /*layout_sensitive=*/false); } @@ -211,8 +238,14 @@ bool ArCrsCombiner::InstructionsComputeSameValue( void ArCrsCombiner::GroupAllReducesById(HloModule* module) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - if (MatchesArCrsPattern(instruction)) { - all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction); + auto maybe_crs = MatchesArCrsPattern(instruction); + if (maybe_crs) { + auto crs = *maybe_crs; + int64 ar_id = *(instruction->all_reduce_id()); + if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) { + all_reduce_map_[ar_id].push_back(instruction); + crs_reserved_map_[crs] = ar_id; + } } } } @@ -229,14 +262,17 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { auto next_0 = instr_0->users()[0]; auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - do { + while (true) { if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { all_reduce_map_.erase(all_reduce_id); break; } + if (next_0->IsCrossReplicaAllReduce()) { + break; + } next_0 = next_0->users()[0]; next_i = next_i->users()[0]; - } while (!next_0->IsCrossReplicaAllReduce()); + } } } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index 6f54b97615b270bc6b180dd47d9aff6473752b47..e61ef5d4f9072979a6c356a9456c91e19405b01e 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -83,6 +83,11 @@ class ArCrsCombiner : public HloModulePass { // Map from all-reduce ids to the all reduce instructions. absl::flat_hash_map> all_reduce_map_; + // Map from a CRS instruction to the all-reduce ID of the AR paired with the + // CRS. Sometimes, several ARs in the code could be paired with the same CRS. + // We use this map to pick a single AR/CRS path to rewrite. + absl::flat_hash_map crs_reserved_map_; + std::unique_ptr call_graph_; }; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index caa57296f465698eb70d7cb8327d4678f394b323..5152f0dc884a153f9b0ade06acd479832d87ff25 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -360,6 +360,7 @@ HloModule foobar ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) %all-reduce.ar.1 = bf16[] all-reduce(%p), @@ -377,7 +378,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { sharding={maximal device=0} %all-reduce.ar.2 = bf16[] - all-reduce(%p), + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, to_apply=%sum.bf16, @@ -407,7 +408,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { EXPECT_TRUE(changed); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::AllReduce(op::Convert(op::Parameter())), - op::AllReduce(op::Convert(op::Parameter())))); + op::AllReduce(op::Convert(op::Constant())))); auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); @@ -705,5 +706,470 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { EXPECT_FALSE(changed); } +TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[], b: f32[]) -> f32[] { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(%a, %b) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%all-reduce.ar.1), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%all-reduce.1, %constant.f32), + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%all-reduce.ar.2), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%all-reduce.2, %constant.f32), + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Parameter()), + op::AllReduce(op::Parameter()))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.11 = f32[] + add(%constant.1, %all-reduce.ar.1), + sharding={maximal device=0} + %add.12 = f32[] + add(%constant.2, %add.11), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%add.12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.21 = f32[] + add(%constant.1, %all-reduce.ar.2), + sharding={maximal device=0} + %add.22 = f32[] + add(%constant.2, %add.21), + sharding={maximal device=0} + %all-reduce.2 = f32[] + all-reduce(%add.22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))), + op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %sub.1 = f32[] + subtract(%constant.f32, %all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%sub.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %sub.2 = f32[] + subtract(%constant.f32, %all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%sub.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())), + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar11, %const1), + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add12 = f32[] + add(%add11, %ar12), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar21, %const1), + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add22 = f32[] + add(%add21, %ar22), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Divide(op::AllReduce(), op::Constant()))), + op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Divide(op::AllReduce(), op::Constant()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar12, %const1), + sharding={maximal device=0} + %add12 = f32[] + add(%ar11, %add11), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar22, %const1), + sharding={maximal device=1} + %add22 = f32[] + add(%ar21, %add21), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Divide(op::Add(op::AllReduce(), op::Constant()), + op::Constant()))), + op::AllReduce(op::Add( + op::Parameter(), + op::Divide(op::Add(op::AllReduce(), op::Constant()), + op::Constant()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 2cf24a9dd5fa18abe9dde4eb49b03c6586bfef03..215e8ced4bb3f98a26ac4eb9912a7fd4d917852f 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -115,12 +115,10 @@ StatusOr Backend::BorrowStream(int device_ordinal) { StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); - if (0 == stream_pools_.count(executor)) { - stream_pools_.emplace(std::piecewise_construct, - std::forward_as_tuple(executor), - std::forward_as_tuple()); + if (!stream_pools_.contains(executor)) { + stream_pools_.emplace(executor, absl::make_unique()); } - return stream_pools_.at(executor).BorrowStream(executor); + return stream_pools_.at(executor)->BorrowStream(executor); } Backend::Backend(se::Platform* platform, Compiler* compiler, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 7ca993fb2656037951d98d9c4459a3c3e4c64c61..c35f033dc0180409ae3888c2050021da83f5c72a 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -175,7 +176,8 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map stream_pools_ GUARDED_BY(mu_); + absl::flat_hash_map> + stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 0e6ca1871b379a2f55b92207133822fc6258b007..e5f5c3edb2ac0c217317fbf809463aa31af9af59 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -123,7 +123,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { auto elements_per_feature_u32 = add_instruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); - for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int64 i = 0; i < operand->shape().rank(); ++i) { if (i == feature_index) { continue; } @@ -229,7 +229,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -357,7 +357,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -494,7 +494,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) { + for (int64 i = 0; i < activation_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 6caef77ed00909040a54e65651cc6fb7ca74eb90..e62d72b323bd1d113e9d87bf8602bfb434c40d61 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -190,7 +190,7 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { } // If the output is not a tuple, we don't need special handling. - if (!ShapeUtil::IsTuple(crs->shape())) { + if (!crs->shape().IsTuple()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index e3aefe906739b74e887f33d2ffc3ad7a60510b5b..d1b14d604f0559b6b18f7d1fba127669c241c8a3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -363,7 +363,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { // TODO(b/112040122): Correctly normalize variadic reduce. if ((hlo->opcode() == HloOpcode::kSort || hlo->opcode() == HloOpcode::kAllReduce) && - ShapeUtil::IsTuple(hlo->shape())) { + hlo->shape().IsTuple()) { return HandleMultipleOutputs(hlo); } return HandleInstruction(hlo); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 05dd4b3e914f5563a33d534829ffb01668279064..bab63f66d83b712d756078bef84926eed235f6b5 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -277,7 +277,7 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, *use.instruction, use.operand_number)) { if (use.instruction->opcode() == HloOpcode::kTuple || (use.instruction->opcode() == HloOpcode::kAllReduce && - ShapeUtil::IsTuple(use.instruction->shape()))) { + use.instruction->shape().IsTuple())) { ShapeIndex use_output_index{use.operand_number}; for (int64 i : use.operand_index) { use_output_index.push_back(i); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 202e45e181d13621f79e3bf95e33091b54e8b779..e1b91b500191c7756f3d1a4b160a0dd1e09cfe7d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -86,10 +86,9 @@ std::vector ColorInterferenceGraph( // first, but it would be good to investigate other ordering heuristics too. std::vector nodes(node_count); std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); + absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); const int64 kColorUnassigned = -1; std::vector assigned_colors(node_count, kColorUnassigned); @@ -138,8 +137,8 @@ Status GatherComputationsByAllocationType( worklist.pop_front(); const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; + bool in_thread_local_set = thread_local_set.contains(computation); + bool in_global_set = global_set.contains(computation); // If the computation has already been added to the respective set, then // nothing to do. @@ -207,9 +206,9 @@ Status GatherComputationsByAllocationType( // Add the computations to the vectors in post order. for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { + if (thread_local_set.contains(computation)) { thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { + } else if (global_set.contains(computation)) { global_computations->push_back(computation); } // If the computation is not reachable from the entry computation, then it @@ -219,13 +218,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { - uint64 h = std::hash()(s.index()); - h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); - h = tensorflow::Hash64Combine(h, std::hash()(s.size())); - return h; -} - string BufferAllocation::Slice::ToString() const { return absl::StrCat("{index:", index(), ", offset:", offset_, ", size:", size_, "}"); @@ -240,7 +232,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); - CHECK(assigned_buffers_.count(&buffer) == 0) + CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; CHECK_LE(offset, size_) << "LogicalBuffer " << buffer @@ -279,11 +271,12 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } - std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), - [](const BufferAllocationProto::Assigned& assign1, - const BufferAllocationProto::Assigned& assign2) { - return assign1.logical_buffer_id() < assign2.logical_buffer_id(); - }); + absl::c_sort(*proto.mutable_assigned(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < + assign2.logical_buffer_id(); + }); return proto; } @@ -315,10 +308,10 @@ string BufferAllocation::ToString() const { for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(sorted_buffers, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); StrAppend(&output, absl::StrFormat( @@ -346,7 +339,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet( bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - return allocation_index_for_buffer_.count(&buffer) > 0; + return allocation_index_for_buffer_.contains(&buffer); } const BufferAllocation& BufferAssignment::GetAssignedAllocation( @@ -401,7 +394,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction, const ShapeIndex& index) const { for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { - if (allocation_index_for_buffer_.count(buffer) > 0) { + if (allocation_index_for_buffer_.contains(buffer)) { return true; } } @@ -459,8 +452,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { - using SliceSet = - flat_hash_set; + using SliceSet = flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -487,10 +479,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, // didn't return the empty set) for both HLOs, and the two resulting sets of // slices are disjoint. return !slices_a.empty() && !slices_b.empty() && - std::none_of(slices_a.begin(), slices_a.end(), - [&](const BufferAllocation::Slice& slice) { - return slices_b.count(slice) > 0; - }); + absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) { + return slices_b.contains(slice); + }); } StatusOr @@ -519,7 +510,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, void BufferAssignment::AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size) { - CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) + CHECK(!allocation_index_for_buffer_.contains(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer: " @@ -761,7 +752,8 @@ namespace { bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, const LogicalBuffer& a_buffer, const LogicalBuffer& b_buffer) { - auto call_graph = assignment->liveness().hlo_ordering().call_graph(); + const CallGraph& call_graph = + assignment->liveness().hlo_ordering().call_graph(); const HloInstruction* a_ancestor; const HloInstruction* b_ancestor; std::tie(a_ancestor, b_ancestor) = @@ -960,35 +952,35 @@ Status BufferAssigner::AssignBuffersForComputation( // operands (assuming operands are the same/larger size) enabling the // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [has_sequential_order, &liveness, &post_order_position, assignment]( - const LogicalBuffer* a, const LogicalBuffer* b) { - // Primary sort is by decreasing buffer size. - const int64 a_size = assignment->buffer_size_(*a); - const int64 b_size = assignment->buffer_size_(*b); - if (a_size != b_size) { - return a_size > b_size; // use ">" for decreasing size. - } - // Otherwise live out buffers come before others, if the - // instructions are sequentially ordered. - if (has_sequential_order) { - const bool a_live_out = liveness.MaybeLiveOut(*a); - const bool b_live_out = liveness.MaybeLiveOut(*b); - if (a_live_out != b_live_out) { - return a_live_out; - } - } - // Final tiebreaker is in instruction post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - }); + absl::c_sort(sorted_buffers, + [has_sequential_order, &liveness, &post_order_position, + assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->buffer_size_(*a); + const int64 b_size = assignment->buffer_size_(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. + } + // Otherwise live out buffers come before others, if the + // instructions are sequentially ordered. + if (has_sequential_order) { + const bool a_live_out = liveness.MaybeLiveOut(*a); + const bool b_live_out = liveness.MaybeLiveOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + } + // Final tiebreaker is in instruction post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); + }); // BufferAllocations are necessarily created in decreasing size order. Keep // indices of previously created BufferAllocations in allocation_indices. std::vector allocation_indices; for (const LogicalBuffer* buffer : sorted_buffers) { VLOG(3) << "Assigning allocation to: " << *buffer; - if (colocated_buffers.count(buffer) > 0) { + if (colocated_buffers.contains(buffer)) { // Colocated buffers are currently assigned in an earlier pass. VLOG(3) << "Skipping colocated buffer: " << *buffer; continue; @@ -1020,10 +1012,14 @@ Status BufferAssigner::AssignBuffersForComputation( // callers. BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer->index()); allocation->set_entry_computation_parameter( - instruction->parameter_number(), buffer->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for entry computation parameter: " << *buffer; + instruction->parameter_number(), buffer->index(), + parameter_has_alias); + VLOG(3) << "Mark allocation #" << allocation->index() + << " as entry computation parameter: " << *buffer; continue; } @@ -1036,7 +1032,7 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (ShapeUtil::IsTuple(buffer->shape())) { + if (buffer->shape().IsTuple()) { BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); allocation->set_is_tuple(true); @@ -1056,7 +1052,7 @@ Status BufferAssigner::AssignBuffersForComputation( assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1087,7 +1083,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Instructions are iterated in increasing buffer size, so any // previously create allocation must be large enough to hold this // instruction's output (with the exception of colocated buffers). - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1313,10 +1309,10 @@ std::vector ComputePeakMemoryLogicalBuffers( live_buffers.end()); // Stabily sort the live buffers. - std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(live_buffers_vector, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); return live_buffers_vector; } @@ -1376,7 +1372,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( std::vector overlap_set_indices; for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { - if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + if ((*colocated_buffer_sets)[index].contains(buffer)) { VLOG(5) << "Found overlap with existing set on buffer " << buffer->ToString() << "\n" << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], @@ -1425,12 +1421,14 @@ BufferAssigner::MergeColocatedBufferSets( << colocated_buffer_sets.size(); // Returns true if the given buffer is for the entry parameter. - auto is_entry_parameter = [](const LogicalBuffer& buffer) { + auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) { auto* instruction = buffer.instruction(); auto* computation = instruction->parent(); auto* module = computation->parent(); return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); + computation == module->entry_computation() && + !module->input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer.index()); }; std::vector set_can_be_merged(colocated_buffer_sets.size(), true); @@ -1452,7 +1450,7 @@ BufferAssigner::MergeColocatedBufferSets( for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { for (auto& buffer : colocated_buffer_sets[i]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || + is_readonly_entry_parameter(*buffer) || buffer->instruction()->opcode() == HloOpcode::kConstant) { set_can_be_merged[i] = false; break; @@ -1539,15 +1537,16 @@ void BufferAssigner::BuildColocatedBufferSets( VLOG(4) << "Input/Output Alias Config: "; VLOG(4) << module->input_output_alias_config(); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { std::vector colocated_set; AddBufferToColocatedSet(module->entry_computation()->root_instruction(), output_index, points_to_analysis, &colocated_set); AddBufferToColocatedSet( - module->entry_computation()->parameter_instruction(param_number), - param_index, points_to_analysis, &colocated_set); + module->entry_computation()->parameter_instruction( + alias.parameter_number), + alias.parameter_index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); @@ -1741,10 +1740,6 @@ void BufferAssigner::AssignColocatedBufferSets( // module-level scope, we can allow buffers to be shared across // computations (in some cases). allocation = assignment->NewAllocation(*buffer, buffer_size); - if (entry_parameter_number >= 0) { - allocation->set_entry_computation_parameter( - entry_parameter_number, *entry_parameter_shape_idx); - } if (is_constant) { allocation->set_constant(true); } @@ -1758,6 +1753,16 @@ void BufferAssigner::AssignColocatedBufferSets( } colocated_buffers->insert(buffer); } + + // If an allocation contains a parameter, set corresponding fields. + if (entry_parameter_number >= 0) { + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + entry_parameter_number, *entry_parameter_shape_idx); + allocation->set_entry_computation_parameter(entry_parameter_number, + *entry_parameter_shape_idx, + parameter_has_alias); + } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 0a9fdede803e84ca42472259084615c031b206eb..448dec3b1aa0c0f85e1060a70e965fcf3952c320 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -96,7 +96,11 @@ class BufferAllocation { // Whether this allocation is readonly i.e. backed by memory we cannot write // to. bool is_readonly() const { - return is_entry_computation_parameter() || is_constant(); + // Entry parameters are generally readonly, except when they are aliased + // with any output. + return (is_entry_computation_parameter() && + !is_parameter_aliased_with_output_) || + is_constant(); } bool is_tuple() const { return is_tuple_; } @@ -186,9 +190,10 @@ class BufferAllocation { end > other.offset_; } - struct Hasher { - size_t operator()(Slice s) const; - }; + template + friend H AbslHashValue(H h, const Slice& s) { + return H::combine(std::move(h), s.index(), s.offset(), s.size()); + } string ToString() const; @@ -273,8 +278,10 @@ class BufferAllocation { void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); void set_entry_computation_parameter(int64 parameter_number, - ShapeIndex param_shape_index) { + ShapeIndex param_shape_index, + bool parameter_aliased_with_output) { is_entry_computation_parameter_ = true; + is_parameter_aliased_with_output_ = parameter_aliased_with_output; parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } @@ -304,6 +311,9 @@ class BufferAllocation { // outlast the computation. bool is_entry_computation_parameter_ = false; + // Whether this entry computation parameter is aliased with output. + bool is_parameter_aliased_with_output_ = false; + // If this allocation holds an entry computation parameter, this field // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29..580bc2f43384006eab8711490689a200fc887d37 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase { static bool BuffersDistinct(const std::vector& a, const std::vector& b, const BufferAssignment& assignment) { - std::set a_slices; + absl::flat_hash_set a_slices; for (const HloInstruction* instruction : a) { if (assignment.HasTopLevelAllocation(instruction)) { a_slices.insert( @@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector& a, for (const HloInstruction* instruction : b) { if (assignment.HasTopLevelAllocation(instruction)) { - if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) - .ConsumeValueOrDie())) { + if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { return false; } } @@ -464,6 +465,40 @@ TEST_F(BufferAssignmentTest, Basic) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { + // If an input buffer and output buffer aliases, the input buffer can be + // reused for other intermediate results. + // + // param0[100] ----- (neg1) -- (neg2) + // | | + // + -------- Aliased ---------+ + + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "p0")); + auto neg_1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + auto neg_2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + auto buffers = RunBufferAssignment(module.get()); + + BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param); + BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {}); + BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {}); + + // Everything use one buffer. + EXPECT_EQ(param_buffer.index(), neg_1_buffer.index()); + EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); +} + TEST_F(BufferAssignmentTest, AddCannotReuse) { // Pass in a special rule to indicate that "add" cannot reuse any buffer. // @@ -2485,9 +2520,9 @@ while_body { get-tuple-element.3 = s32[] get-tuple-element(state), index=0 constant.2 = s32[] constant(128) add.5 = s32[] add(get-tuple-element.3, constant.2) - constant.3 = s32[3]{0} constant({0, 0, 0}) - dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3) - dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 40825a78716b1c0b9fb0121787977d275891c0f8..23b9af0281b0d5ee1ef6ca2315f0cc1042285609 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -52,8 +52,8 @@ class BufferLivenessTest : public HloTestBase { // interfere. Precondition: 'a' and 'b' are array-shaped. bool InstructionsMayInterfere(const BufferLiveness& liveness, HloInstruction* a, HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return liveness.MayInterfere( GetBuffer(liveness, /*instruction=*/a, /*index=*/{}), GetBuffer(liveness, /*instruction=*/b, /*index=*/{})); @@ -66,8 +66,8 @@ class BufferLivenessTest : public HloTestBase { HloInstruction* a, HloInstruction* b, const ShapeIndex& index) { // Check that top-level shapes are tuple and tuple element shapes are equal. - EXPECT_TRUE(ShapeUtil::IsTuple(a->shape())); - EXPECT_TRUE(ShapeUtil::IsTuple(b->shape())); + EXPECT_TRUE(a->shape().IsTuple()); + EXPECT_TRUE(b->shape().IsTuple()); EXPECT_TRUE( ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index), ShapeUtil::GetSubshape(b->shape(), index))); @@ -638,10 +638,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -794,10 +794,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index fdf822c666b15afbc7553ca89d4f92ab08201869..b1abba20689915b03304aacd7a5fcca5443c2c60 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -29,8 +29,8 @@ BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id) : id_(id) { const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); - is_array_ = ShapeUtil::IsArray(shape); - is_tuple_ = ShapeUtil::IsTuple(shape); + is_array_ = shape.IsArray(); + is_tuple_ = shape.IsTuple(); } BufferValue::~BufferValue() {} diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 173b3fc05f53d523fb07ef9b14be884fd5f8aeb1..94af788c54f6c722997311bec50da3ed93aa3cee 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -236,6 +236,41 @@ void CallGraph::SetCallContexts() { } } +void CallGraph::SetNodeDepths() { + std::queue worklist; + + // Initialize node depths to -1. + for (CallGraphNode& node : nodes_) { + node.set_depth(-1); + } + + // Initialize worklist with all roots of the call graph (computations without + // callers). + for (const HloComputation* computation : module_->computations()) { + CallGraphNode& node = GetNode(computation); + if (node.callers().empty()) { + node.set_depth(0); + worklist.push(&node); + } + } + + while (!worklist.empty()) { + CallGraphNode* node = worklist.front(); + worklist.pop(); + for (const HloComputation* callee : node->callees()) { + CallGraphNode& callee_node = GetNode(callee); + if (callee_node.depth() < node->depth() + 1) { + callee_node.set_depth(node->depth() + 1); + worklist.push(&callee_node); + } + } + } + + for (CallGraphNode& node : nodes_) { + CHECK_NE(node.depth(), -1); + } +} + /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so absl::make_unique can't be used. @@ -271,6 +306,8 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { } call_graph->SetCallContexts(); + call_graph->SetNodeDepths(); + XLA_VLOG_LINES(1, call_graph->ToString()); return call_graph; @@ -352,15 +389,38 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, // Iterate through the callee->caller chains and find the earliest common // element. - for (HloInstruction* a_ancestor = a; a_ancestor != nullptr; - a_ancestor = next_caller(a_ancestor)) { - for (HloInstruction* b_ancestor = b; b_ancestor != nullptr; - b_ancestor = next_caller(b_ancestor)) { - if (a_ancestor->parent() == b_ancestor->parent()) { - return {a_ancestor, b_ancestor}; + HloInstruction* a_ancestor = a; + HloInstruction* b_ancestor = b; + int a_depth = GetNode(a->parent()).depth(); + int b_depth = GetNode(b->parent()).depth(); + + // Advance a_ancestor (b_ancestor) up the call chain until the call depth of + // a_ancestor or b_ancestor are the same. Necessarily each call to next_caller + // reduces the depth by exactly one. + if (a_depth > b_depth) { + for (int i = 0; i < a_depth - b_depth; ++i) { + a_ancestor = next_caller(a_ancestor); + if (a_ancestor == nullptr) { + return {nullptr, nullptr}; + } + } + } else if (b_depth > a_depth) { + for (int i = 0; i < b_depth - a_depth; ++i) { + b_ancestor = next_caller(b_ancestor); + if (b_ancestor == nullptr) { + return {nullptr, nullptr}; } } } + + while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + + a_ancestor = next_caller(a_ancestor); + b_ancestor = next_caller(b_ancestor); + } return {nullptr, nullptr}; } diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 05c7c998738f861ee804d1ec87bfa5fb17ddfb74..c02ffda575278905f6549b362e5e7d94f5713b36 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -121,6 +121,11 @@ class CallGraphNode { // Returns the context in which this computation is called. CallContext context() const { return context_; } + // Returns the depth of this node in the call graph. The depth is defined as + // the length of the longest call chain from a computation with no callers + // (usually the entry computation node) to this node. + int depth() const { return depth_; } + string ToString() const; private: @@ -130,6 +135,9 @@ class CallGraphNode { // Sets the context in which this computation is called. void set_context(CallContext value) { context_ = value; } + // Sets the depth of this node in the graph. + void set_depth(int value) { depth_ = value; } + // Adds a callsite which calls this computation. Updates callers to include // the calling computation. void AddCallerCallSite(const CallSite& caller_callsite); @@ -164,6 +172,9 @@ class CallGraphNode { // The context in which this computation is called. CallContext context_ = CallContext::kNone; + + // The depth of this node in the call graph. + int depth_ = 0; }; // The call graph for an HLO module. The graph includes a node for each @@ -245,9 +256,16 @@ class CallGraph { private: CallGraph(const HloModule* module); + // Not copyable. + CallGraph(const CallGraph&) = delete; + CallGraph& operator=(const CallGraph&) = delete; + // Sets the call contexts for every node in the graph. void SetCallContexts(); + // Sets the call node depths for every node in the graph. + void SetNodeDepths(); + // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. Skips nodes already in diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index a3ac2568b0f3eec8556a42dbe3c2c64bd8564468..5de724f8924b78008ba4c56603b61bf93fbc5e7c 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -102,6 +102,7 @@ TEST_F(CallGraphTest, SingletonComputation) { const CallGraphNode& node = call_graph->GetNode(computation); EXPECT_EQ(computation, node.computation()); + EXPECT_EQ(node.depth(), 0); EXPECT_TRUE(node.callsites().empty()); EXPECT_TRUE(node.callees().empty()); EXPECT_TRUE(node.caller_callsites().empty()); @@ -122,11 +123,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(CallContext::kSequential, entry_node.context()); const CallGraphNode& unreachable_node = call_graph->GetNode(unreachable_computation); + EXPECT_EQ(unreachable_node.depth(), 0); EXPECT_EQ(unreachable_computation, unreachable_node.computation()); EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); } @@ -145,6 +148,7 @@ TEST_F(CallGraphTest, ParallelComputation) { const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(CallContext::kSequential, entry_node.context()); EXPECT_EQ(5, entry_node.callsites().size()); EXPECT_EQ(1, entry_node.callees().size()); @@ -153,6 +157,7 @@ TEST_F(CallGraphTest, ParallelComputation) { const CallGraphNode& map_node = call_graph->GetNode(map_computation); EXPECT_EQ(map_computation, map_node.computation()); + EXPECT_EQ(map_node.depth(), 1); EXPECT_EQ(CallContext::kParallel, map_node.context()); EXPECT_TRUE(map_node.callsites().empty()); EXPECT_TRUE(map_node.callees().empty()); @@ -234,6 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); + EXPECT_EQ(sub_node.depth(), 1); EXPECT_EQ(CallContext::kBoth, sub_node.context()); } @@ -264,6 +270,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(3, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(1, entry_node.callsites().size()); @@ -275,11 +282,13 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_EQ(true_node.depth(), 1); EXPECT_TRUE(true_node.callees().empty()); EXPECT_EQ(1, true_node.callers().size()); EXPECT_EQ(entry_computation, true_node.callers()[0]); const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_EQ(false_node.depth(), 1); EXPECT_TRUE(false_node.callees().empty()); EXPECT_EQ(1, false_node.callers().size()); EXPECT_EQ(entry_computation, false_node.callers()[0]); @@ -332,9 +341,21 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + const CallGraphNode& a_node = call_graph->GetNode(a_computation); + const CallGraphNode& b_node = call_graph->GetNode(b_computation); + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + + // Verify depths. + EXPECT_EQ(entry_node.depth(), 0); + EXPECT_EQ(a_node.depth(), 1); + EXPECT_EQ(b_node.depth(), 2); + EXPECT_EQ(c_node.depth(), 3); + EXPECT_EQ(cond_node.depth(), 2); + // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). - const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); ASSERT_EQ(1, entry_node.callsites().size()); const std::vector& called_computations = entry_node.callsites()[0].called_computations(); @@ -342,7 +363,6 @@ TEST_F(CallGraphTest, ComplexGraph) { UnorderedElementsAre(cond_computation, a_computation)); EXPECT_EQ(CallContext::kSequential, entry_node.context()); - const CallGraphNode& c_node = call_graph->GetNode(c_computation); EXPECT_TRUE(c_node.callsites().empty()); EXPECT_THAT(c_node.callers(), UnorderedElementsAre(a_computation, b_computation)); @@ -364,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // Verify visitation order of some computations in the graph. auto index_of = [&visited](const HloComputation* comp) { - auto it = std::find(visited.begin(), visited.end(), comp); + auto it = absl::c_find(visited, comp); EXPECT_NE(it, visited.end()); return std::distance(visited.begin(), it); }; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7..b517495f2ea0c75679685c67f757ff586f8c79e3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { } Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; @@ -94,7 +94,7 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 52037bf9b52556c6aa2e66dd3209e25cf085cfe3..89e17eba36f23077ce4cf0704e7455b76bee68d1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" @@ -83,7 +84,8 @@ class ChannelTracker { // Mapping from ChannelHandle value to the corresponding registered // Channel object. - std::map opaque_to_channel_ GUARDED_BY(channel_mutex_); + absl::flat_hash_map opaque_to_channel_ + GUARDED_BY(channel_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); }; diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8f08c244908efb823b3870c19bdc3491fa87d44f..653f4555a77cc82e91fb1cd26206b93826375732 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -98,10 +98,17 @@ Compiler::GetPlatformCompilers() { auto* factories = GetPlatformCompilerFactories(); auto it = factories->find(platform->id()); if (it == factories->end()) { + string hint; + if (platform->Name() == "Host") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_cpu_jit)"; + } else if (platform->Name() == "CUDA") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_gpu_jit)"; + } + return NotFound( "could not find registered compiler for platform %s -- check " - "target linkage", - platform->Name()); + "target linkage%s", + platform->Name(), hint); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index efc893818d03a20d6bd65b7dc1da72ea5da5ceb0..92d1ca4ba5da802a5f1c544017ac52dda38e9b1d 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() { } bool ComputationLayout::LayoutIsSet() const { - return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), - [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && + return absl::c_all_of(parameter_layouts_, + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && result_layout_.LayoutIsSet(); } diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index 7a24faec17f0c4f0a57406328b1c21cd73506d82..1c1f5431700f4ee0cebc3146654feff620ee978c 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -207,7 +207,8 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { return Status::OK(); } - VLOG(2) << "Dealing with batch_group_count " << batch_group_count << "\n"; + VLOG(2) << "Dealing with batch_group_count " << batch_group_count + << " for convolution " << convolution->ToString() << "\n"; auto add = [&](std::unique_ptr inst) { return computation_->AddInstruction(std::move(inst)); @@ -315,14 +316,27 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution, zero_filter)); - auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(F32)); + PrimitiveType reduce_type = new_filter->shape().element_type(); + auto reduce_window_shape = new_convolution->shape(); + reduce_window_shape.set_dimensions(output_batch_dimension, 1); + + // Ensure that data input to reduce window uses at least 32 bits. + if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) { + reduce_type = F32; + reduce_window_shape.set_element_type(F32); + Shape convert_shape = new_filter->shape(); + convert_shape.set_element_type(F32); + new_filter = + add(HloInstruction::CreateConvert(convert_shape, new_filter)); + } + + auto zero_literal = LiteralUtil::Zero(reduce_type); auto zero_scalar = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto reduce_function = [&]() -> HloComputation* { HloComputation::Builder b("add_computation"); - Shape shape = ShapeUtil::MakeShape(F32, {}); + Shape shape = ShapeUtil::MakeShape(reduce_type, {}); auto lhs = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); auto rhs = @@ -332,18 +346,6 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); }; - // Ensure that data input to reduce window is of type F32. - if (primitive_util::BitWidth(new_filter->shape().element_type()) < - primitive_util::BitWidth(F32)) { - Shape convert_shape = new_filter->shape(); - convert_shape.set_element_type(F32); - new_filter = - add(HloInstruction::CreateBitcastConvert(convert_shape, new_filter)); - } - - auto reduce_window_shape = new_convolution->shape(); - reduce_window_shape.set_dimensions(output_batch_dimension, 1); - // Create the reduce window. Window window; for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { @@ -369,7 +371,7 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { // Convert reduced data back to the original data type. auto reduce_window_converted = - HloInstruction::CreateBitcastConvert(convert_back_shape, reduce_window); + HloInstruction::CreateConvert(convert_back_shape, reduce_window); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(reduce_window_converted))); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc similarity index 75% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc rename to tensorflow/compiler/xla/service/convolution_group_converter_test.cc index d58f157242f5fb9690f7fda3e7d8f71ca6c8db84..585b81a5db632901be863893bf723fcba19388ea 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -94,5 +94,32 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } +TEST_F(ConvolutionGroupConverterTest, + ConvertBatchGroupCountEqualToInputBatchDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16,19,19,512]{3,2,1,0}) -> f32[3,3,512,1]{3,2,1,0} { + %input = f32[16,19,19,512]{3,2,1,0} parameter(0) + %filter = f32[16,19,19,512]{3,2,1,0} parameter(1) + ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto cost_model = [](HloInstruction* conv) { return true; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ + true); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with batch_group_count = 1. + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(0)->batch_group_count(), 1); + // Verify that the convolution is replaced by a reshape. + EXPECT_EQ(root->opcode(), HloOpcode::kReshape); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index df6059663876dfde71f4c75d3931b3d2de72c1df..5e26a63cebfa9b2e50f4b13335c10c246999d4df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -349,11 +349,12 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { ShapeTree param_indices_to_copy(param->shape()); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - if (param_number == param->parameter_number()) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (alias.parameter_number == param->parameter_number()) { param_has_alias = true; - *(param_indices_to_copy.mutable_element(param_index)) = true; + *(param_indices_to_copy.mutable_element(alias.parameter_index)) = + true; *(output_indices_to_copy.mutable_element(output_index)) = true; } }); @@ -395,13 +396,14 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { // Add control dependencies between the input/output copies. TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& input_index) -> Status { - if (!copied_parameters[param_number]) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) -> Status { + if (!copied_parameters[alias.parameter_number]) { return Status::OK(); } HloInstruction* from = - copied_parameters[param_number]->element(input_index); + copied_parameters[alias.parameter_number]->element( + alias.parameter_index); HloInstruction* to = output_copy_tree.element(output_index); TF_RET_CHECK(from != nullptr); @@ -522,7 +524,7 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { - if (ShapeUtil::IsToken(value_a->shape())) { + if (value_a->shape().IsToken()) { // Token values have no representation and cannot interfere. continue; } @@ -539,10 +541,9 @@ class CopyRemover { } std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); // Create a list containing all of the values in the buffer. AddValueList(values, &value_to_node); @@ -842,12 +843,11 @@ class CopyRemover { copy_value_node->next->prev = operand_node; // Patch up uses. Remove use of copy from operand_node uses. - auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); + auto it = absl::c_find_if( + operand_node->uses, [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); CHECK(it != operand_node->uses.end()); operand_node->uses.erase(it); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e4e9d7ba05c115be9dd0eb53ebd7de208d514efb..4391bdcba532661a0fde789e2c4ed324c40bcd32 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1376,9 +1376,11 @@ TEST_F(CopyInsertionTest, CrossingParameters) { builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 4); @@ -1409,9 +1411,11 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1475,7 +1479,8 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1516,7 +1521,8 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1557,7 +1563,8 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1848,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - HloRunner::CreateModuleFromString( - module_string, GetDebugOptionsForTest())); + ParseAndReturnVerifiedModule(module_string)); InsertCopies(module.get()); // There should be no copies added because tokens should not be copied. @@ -2112,8 +2118,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2213,8 +2218,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2231,7 +2235,7 @@ cond.inner { body.inner { param.body.inner = pred[] parameter(0) - ROOT neg = pred[] negate(param.body.inner) + ROOT not = pred[] not(param.body.inner) } cond.outer { @@ -2248,9 +2252,8 @@ ENTRY TestComputation { ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); InsertCopies(module.get()); // There should only be a single copy inserted, and it's in the entry diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index f49b5110be5c4bab63b423e5ed2e67bc1828f6e3..d4535b204d7f3ad8d4e24beea5d0dd79e7a15ab0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -1,6 +1,14 @@ # Description: # LLVM-based CPU backend for XLA. +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "mkl_deps", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load(":build_defs.bzl", "runtime_copts") + licenses(["notice"]) # Apache 2.0 package( @@ -14,15 +22,6 @@ package_group( ], ) -load(":build_defs.bzl", "runtime_copts") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") -load( - "//third_party/mkl:build_defs.bzl", - "mkl_deps", -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -95,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -133,6 +134,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -241,6 +243,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/host:host_stream", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -364,15 +367,33 @@ cc_library( ], ) +cc_library( + name = "tiled_dot_emitter", + srcs = ["tiled_dot_emitter.cc"], + hdrs = ["tiled_dot_emitter.h"], + deps = [ + ":vector_support_library", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], - hdrs = ["dot_op_emitter.h"], + hdrs = [ + "dot_op_emitter.h", + ], deps = [ ":cpu_options", ":cpu_runtime", ":ir_emission_utils", ":target_machine_features", + ":tiled_dot_emitter", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -380,6 +401,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -631,6 +653,7 @@ cc_library( deps = [ ":runtime_matvec", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -767,8 +790,6 @@ cc_library( ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1008,7 +1029,6 @@ tf_cc_test( size = "small", srcs = ["cpu_eigen_tensor_alignment_test.cc"], deps = [ - ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 796a7cf94d02b0ad42366387a9d3f8d589b8840a..414eacddfc7ba3c295c027c64c445a2046235d36 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -66,9 +66,14 @@ class FilteredPassManager : public llvm::legacy::PassManager { explicit FilteredPassManager(bool disable_expensive_passes) : disable_expensive_passes_(disable_expensive_passes) {} void add(llvm::Pass* p) override { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("Warn about non-applied transformations")) { + delete p; + return; + } if (disable_expensive_passes_) { - llvm::StringRef PassName = p->getPassName(); if (PassName.contains("Unroll loops")) { + delete p; return; } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index ba7dcde5c3d7e0406f46d642632f780d6d7db54f..eafda68510d93ee54f2aead60a84f3e97b3fe1f4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -69,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -92,6 +93,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -244,6 +246,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -256,7 +259,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( // pass. pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(/*decompose_batch_dot=*/false); auto cost_model = [](HloInstruction* conv) { // We need a cost model for CPUs. Currently, do nothing. return false; @@ -279,10 +282,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); pipeline.AddPass(); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); + AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); + pass.AddPass(); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -302,7 +305,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) + return DotImplementationCanHandleTranspose(dot, + *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -345,8 +349,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return true; }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_dot_strength_reduction(false); pass.AddPass>(options); @@ -506,7 +509,7 @@ Status CreateHloProfilingArtifacts( auto shape_size_bytes = [](const Shape& shape) { // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return static_cast(sizeof(void*)); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index 8727c72b6e42517b1859e98ecadb41bbceed761c..485769a373acf5ae70c471b1a5dfcfb20ff772ef 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -28,37 +27,6 @@ namespace { class CpuEigenTensorAlignmentTest : public ::testing::Test {}; -TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { - string hlo_string = R"( -HloModule DotOperation - -ENTRY DotOperation { - arg0 = f32[5,256] parameter(0) - arg1 = f32[256,1024] parameter(1) - ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_string)); - - HloInstruction* dot = module->entry_computation()->root_instruction(); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( - [](int64 size) { return 1; }); - - EXPECT_FALSE( - PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); - - TargetMachineFeaturesWithFakeAlignmentLogic - target_machine_with_full_alignment([](int64 size) { - return TargetMachineFeatures::kEigenExpectedTensorAlignment; - }); - - EXPECT_TRUE(PotentiallyImplementedAsEigenDot( - *dot, target_machine_with_full_alignment)); -} - TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { string hlo_string = R"( HloModule ConvOperation diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 818b2b0d0db2893e11fa46c7867e6c74bbbb6905..23d0af34233858515af21df5e92346742a5b5dc3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -213,6 +213,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); // Move OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the @@ -232,12 +234,31 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - const BufferAllocation::Index buffer_index = slice.index(); OwningDeviceMemory& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer.Forget(); + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, the + // owning buffer will be null, and that means the caller aliased some + // parameter buffer to an output one (via the + // HloInputOutputAliasConfig API). If that is the case, the caller + // will receive a partially complete scoped shaped buffer, which they + // will have to fill up on return. Unfortunately the interface to the + // execute APIs are ShapedBuffer pointer based, which assumes caller + // ownership, and hence a buffer coming from there cannot be part of + // the new ScopedShapedBuffer we create for the result (which assumes + // ownership). + *device_memory = buffer.Forget(); + } else { + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } return Status::OK(); })); return std::move(result_buffer); @@ -326,7 +347,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 527df0bd1c23bba74f32226e5622fed32f7dcf84..c4bde837e57e82584c2a007858ed8d55608acd3c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -332,7 +332,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); - Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); + Shape starts_shape = ShapeUtil::MakeShape(F32, {}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -340,13 +340,15 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, starts_shape, "starts")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); HloInstruction* reshape3 = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, reshape3, param1, {4, 4})); + dynamic_slice_shape, reshape3, {param1, param2}, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -356,7 +358,8 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, - HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); + HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, Broadcast_Negate) { @@ -381,14 +384,14 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {4}); - Shape slice_shape = ShapeUtil::MakeShape(F32, {1}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {}); Shape result_shape = ShapeUtil::MakeShape(F32, {2}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, slice_shape, "starts")); HloInstruction* dynamic_slice2 = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2})); + HloInstruction::CreateDynamicSlice(result_shape, param0, {param1}, {2})); builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); @@ -548,28 +551,36 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, full_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, full_shape, "to_update")), + slice, update_indices)); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, - HloOpcode::kParameter, HloOpcode::kParameter, - HloOpcode::kParameter, HloOpcode::kParameter}); + module.get(), + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { @@ -578,49 +589,40 @@ TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); - auto loop_idx = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "param0")))); - + auto loop_idx = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param0")); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {1}), "param1")); - auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, {5}), - {loop_idx, param1, param1, param1, param1}, /*dimension=*/0)); + 1, ShapeUtil::MakeShape(S32, {}), "param1")); - auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(S32, {4}), "param2")), - loop_idx, - /*slice_sizes=*/{1})); - - PaddingConfig padding_config; - padding_config.add_dimensions()->set_edge_padding_high(4); - auto pad = builder.AddInstruction(HloInstruction::CreatePad( - ShapeUtil::MakeShape(S32, {5}), idx_choice, - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), - padding_config)); + auto idx_choice = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(S32, {4}), "param2")), + {loop_idx}, + /*slice_sizes=*/{1})))); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), - pad, /*slice_sizes=*/{1, 100, 10, 100, 50})); + {idx_choice, zero, zero, zero, zero}, + /*slice_sizes=*/{1, 100, 10, 100, 50})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( HloInstruction::CreateParameter(4, full_shape, "param4")), - slice, concat)); + slice, {loop_idx, param1, param1, param1, param1})); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice, - HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice, + HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); } @@ -930,9 +932,10 @@ ENTRY main { return result; } -INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, - ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), - GatherLoopFusionTestSpec::Name); +INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation, + GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index c291bf2d1ba2eaff4192051840768c037bece86f..95b8025f873c56bea063ff258d4abd6614257d85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { for (auto* user : instruction->users()) { optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); if (!operand_idx || user->operand(*operand_idx) != instruction || - std::count(user->operands().begin(), user->operands().end(), - instruction) != 1) { + absl::c_count(user->operands(), instruction) != 1) { return false; } } @@ -94,60 +93,38 @@ static Shape ColMajorShape(const Shape& old_shape) { return new_shape; } +static bool OperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& instr, + const TargetMachineFeatures& target_machine_features) { + if (instr.opcode() == HloOpcode::kConvolution) { + return PotentiallyImplementedAsEigenConvolution(instr, + target_machine_features); + } else if (instr.opcode() == HloOpcode::kDot) { + return DotOperandsAndResultMustHaveRowMajorLayout(instr, + target_machine_features); + } + return false; +} + Status CpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) { - const HloInstruction* convolution = instruction; - const HloInstruction* lhs_instruction = convolution->operand(0); - const HloInstruction* rhs_instruction = convolution->operand(1); - - // In order to implement `convolution` with Eigen convolution, the layouts - // of the input, filter, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(convolution->shape())); - Shape input_shape(RowMajorShape(lhs_instruction->shape())); - Shape filter_shape(RowMajorShape(rhs_instruction->shape())); - - // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, convolution, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, convolution, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, convolution)); + if (OperandsAndResultMustHaveRowMajorLayout(*instruction, + target_machine_features_)) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + RowMajorShape(instruction->shape()), instruction)); + for (int i = 0; i < instruction->operand_count(); i++) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + RowMajorShape(instruction->operand(i)->shape()), instruction, i)); + } } else if (optional op_idx = ShouldMakeOperandColumnMajor(&cache, *instruction)) { const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_)) { - const HloInstruction* dot = instruction; - // In order to implement `dot` with Eigen dot, the layouts of the lhs, - // rhs, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(dot->shape())); - - const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - - // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); } else { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { @@ -160,7 +137,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( continue; } // Skip operands with non-array shapes. - if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + if (!instruction->operand(operand_no)->shape().IsArray()) { continue; } Shape operand_shape( @@ -175,7 +152,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } // Skip instructions which don't produce array shapes (tuples, opaque, // etc.). - if (!ShapeUtil::IsArray(instruction->shape())) { + if (!instruction->shape().IsArray()) { continue; } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 92debb83e33b1400a59e5eef0f90971392ab7b22..ff654c83d61e7cc09ac7839feccaf2bc9cb3c63c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -23,8 +23,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; -const char* const kXlaEnableExperimentalLlvmIrGemm = - "xla_enable_experimental_llvm_ir_gemm"; +const char* const kXlaForceEnableExperimentalLlvmIrGemm = + "xla_force_enable_experimental_llvm_ir_gemm"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -57,10 +57,10 @@ absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { return absl::nullopt; } -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; + return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } static absl::string_view RemoveSuffix(absl::string_view str, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0..99e6702d14aed8ffb148adec2bdd02dbc7c3c7e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,7 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index a9febe891b5e9d1eb9e6b297952b50d1d26a3396..d8878e622c0500fc5328aa6c295a9e24a3a037f7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; -extern const char* const kKeyValueSortPREDSymbolName = - "__xla_cpu_runtime_KeyValueSortPRED"; -extern const char* const kKeyValueSortS8SymbolName = - "__xla_cpu_runtime_KeyValueSortS8"; -extern const char* const kKeyValueSortU8SymbolName = - "__xla_cpu_runtime_KeyValueSortU8"; -extern const char* const kKeyValueSortS16SymbolName = - "__xla_cpu_runtime_KeyValueSortS16"; -extern const char* const kKeyValueSortU16SymbolName = - "__xla_cpu_runtime_KeyValueSortU16"; -extern const char* const kKeyValueSortF16SymbolName = - "__xla_cpu_runtime_KeyValueSortF16"; -extern const char* const kKeyValueSortS32SymbolName = - "__xla_cpu_runtime_KeyValueSortS32"; -extern const char* const kKeyValueSortU32SymbolName = - "__xla_cpu_runtime_KeyValueSortU32"; -extern const char* const kKeyValueSortF32SymbolName = - "__xla_cpu_runtime_KeyValueSortF32"; -extern const char* const kKeyValueSortS64SymbolName = - "__xla_cpu_runtime_KeyValueSortS64"; -extern const char* const kKeyValueSortU64SymbolName = - "__xla_cpu_runtime_KeyValueSortU64"; -extern const char* const kKeyValueSortF64SymbolName = - "__xla_cpu_runtime_KeyValueSortF64"; - +extern const char* const kKeyValueSortSymbolName = + "__xla_cpu_runtime_KeyValueSort"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index b2e760a224ad8eaa61dae57b0f9cece04a7e54ae..3a2b44d8c1a80128d3577c374e751e73a89e9d59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; -extern const char* const kKeyValueSortPREDSymbolName; -extern const char* const kKeyValueSortS8SymbolName; -extern const char* const kKeyValueSortU8SymbolName; -extern const char* const kKeyValueSortS16SymbolName; -extern const char* const kKeyValueSortU16SymbolName; -extern const char* const kKeyValueSortF16SymbolName; -extern const char* const kKeyValueSortS32SymbolName; -extern const char* const kKeyValueSortU32SymbolName; -extern const char* const kKeyValueSortF32SymbolName; -extern const char* const kKeyValueSortS64SymbolName; -extern const char* const kKeyValueSortU64SymbolName; -extern const char* const kKeyValueSortF64SymbolName; +extern const char* const kKeyValueSortSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e..4e8c98678309fa4d573f1aac1290c9afc87643a4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -162,11 +162,12 @@ TEST_P(EigenMatMulTest, DoIt) { CheckMatrixMultiply(*a, *b, *c); } -INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, - ::testing::Combine(::testing::ValuesIn(MatMulShapes), - ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), - EigenMatMulTest::Name); +INSTANTIATE_TEST_SUITE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, + ::testing::Combine(::testing::ValuesIn(MatMulShapes), + ::testing::Bool(), + ::testing::Bool(), + ::testing::Bool()), + EigenMatMulTest::Name); #ifdef INTEL_MKL class MKLMatMulTest : public CpuRuntimeTest, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 1457582ac19c27e5c3150b4667e6af505345a6bd..3361a5973f5e8c91802b26d68477347b196d3cac 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -97,7 +97,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { int64 size = GetByteSizeRequirement(shape); return TransferBufferToInfeed(executor, size, literal.untyped_data()); } @@ -178,7 +178,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, MutableBorrowingLiteral literal) { - if (!ShapeUtil::IsTuple(literal_shape)) { + if (!literal_shape.IsTuple()) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index 3ae64142cd7e32d3aa8d50870efaf94698c06440..c3c6847b7b77e2fb0470630815de9f5d7a6c5b9c 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -77,17 +77,16 @@ StatusOr Disassembler::DisassembleObjectFile( } // Sort the symbols in increasing address order. - std::sort( - symbols.begin(), symbols.end(), - [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) { - // getAddress returns a Expected object. Assert there is no error - // before extracting the address. - llvm::Expected a_address_or_error = a.getAddress(); - CHECK(a_address_or_error); - llvm::Expected b_address_or_error = b.getAddress(); - CHECK(b_address_or_error); - return a_address_or_error.get() < b_address_or_error.get(); - }); + absl::c_sort(symbols, [](const llvm::object::SymbolRef& a, + const llvm::object::SymbolRef& b) { + // getAddress returns a Expected object. Assert there is no error + // before extracting the address. + llvm::Expected a_address_or_error = a.getAddress(); + CHECK(a_address_or_error); + llvm::Expected b_address_or_error = b.getAddress(); + CHECK(b_address_or_error); + return a_address_or_error.get() < b_address_or_error.get(); + }); // Construct ArrayRef pointing to section contents. llvm::StringRef section_content_string; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 37cefcb2e827ffd15aa489b1b3199ba9f27d9dd6..48510181bd01c87c9db764396b556fdf34e6c8c4 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -26,7 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -41,931 +44,165 @@ namespace xla { using llvm_ir::SetToFirstInsertPoint; namespace cpu { - namespace { -// Provides tiled access to an in-memory rank 2 array. -class MemoryTile { - public: - // Constructs a MemoryTile that can operate on tiles consisting of - // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at - // `major_dim_offset` in the major dimension. The tile size along the minor - // dimension is the vector size, and that is implicitly determined by `vsl`. - MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, - llvm::Value* matrix, int64 matrix_size_along_minor_dim, - llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), b_(b) { - pointers_.reserve(tile_size_along_major_dim); - for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = - b->CreateMul(b->getInt64(matrix_size_along_minor_dim), - b->CreateAdd(b->getInt64(i), major_dim_offset)); - pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); - } - } - - // Load a tile consisting of `tile_size_along_major_dim` vectors from position - // {major: `major_dim_offset`, minor: `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector LoadTile(llvm::Value* minor_dim_offset) const { - std::vector result; - result.reserve(pointers_.size()); - for (const auto& pointer : pointers_) { - result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); - } - return result; - } - - // Stores `tile` to position {major: `major_dim_offset`, minor: - // `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(absl::Span tile, - llvm::Value* minor_dim_offset) const { - CHECK_EQ(tile.size(), pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); - } - } - - // Loads a tile of size [`tile_size_along_major_dim`, - // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, - // minor: `minor_dim_offset`} and then broadcasts each element into a vector - // of size vsl_.vector_size(). The (i,j)'th element of the return value is - // the (i,j)'th element in the tile broadcasted into an LLVM vector. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector> LoadBroadcastTile( - llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { - std::vector> result; - result.resize(pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - for (int64 j = 0; j < tile_size_along_middle_dim; j++) { - result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); - } - } - return result; - } - - private: - VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* b_; - std::vector pointers_; -}; - -// The base class for the classes representing the GEMV emitter configurations. -// -// The IR emitted (modulo the LLVM values representing the input and output -// buffers) by the row major and column major GEMV emitters should be a function -// of their configuration. This is important because their configuration is -// used as a key to cache the generated IR. -class GemvConfig { - public: - // Mixin for convenience. - template - struct User { - public: - PrimitiveType scalar_type() const { - return derived().config().scalar_type(); - } - int64 tile_rows() const { return derived().config().tile_rows(); } - int64 tile_cols() const { return derived().config().tile_cols(); } - int64 m() const { return derived().config().m(); } - int64 k() const { return derived().config().k(); } - int64 has_addend() const { return derived().config().has_addend(); } - - private: - const T& derived() const { return *static_cast(this); } - }; +// Returns true if we should call into multi-threaded Eigen routines. +bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) { + return config.debug_options().xla_cpu_multi_thread_eigen(); +} - PrimitiveType scalar_type() const { return scalar_type_; } - int64 tile_rows() const { return tile_rows_; } - int64 tile_cols() const { return tile_cols_; } - int64 m() const { return m_; } - int64 k() const { return k_; } - bool has_addend() const { return has_addend_; } - - string GetCacheKey() const { - return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", - tile_rows(), "_", tile_cols(), "_", m(), "_", k(), - has_addend() ? "_with_addend" : ""); +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + DotInfo() = default; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); } - - protected: - explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, bool has_addend) - : name_(std::move(name)), - scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - has_addend_(has_addend) {} - - private: - string name_; - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - bool has_addend_; }; -// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ +--+--+--+--+ -// |M00|M10|M20|M30| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M03|M13|M23|M33| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// -// (Legend: rows are horizontal and columns are vertical; and each column is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is from the column major left matrix. -// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] -// vector loaded from the RHS vector. -// -// As we iterate through the column dimension, we compute the change to the -// result vector by an elementwise multiplication between the two tiles above -// followed by a reduction along the major dimension: -// -// +-----------------------------------+ -// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | -// +-----------------------------------+ -// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | -// Result[R:R+4] += +-----------------------------------+ -// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | -// +-----------------------------------+ -// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | -// +-----------------------------------+ -// -// Where R is the starting row for the tile. -// -// We have an inner epilogue loop to deal with the "C" submatrix and an outer -// epilogue loop to deal with the B,D submarix. -// -// TODO(sanjoy): We should investigate if using gather loads and scatter stores -// can be used here have the same inner loop for both column-major and row-major -// matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter - : public GemvConfig::User { - public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"col_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { - CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } - - private: - void EmitOuterLoopBody(llvm::Value* column, int64 column_count, - bool is_first_column); - - MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m(), - /*major_dim_offset=*/column_start, - /*tile_size_along_major_dim=*/column_count); - } - - // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequence of `count` values, each one broadcasted to the vector width. - std::vector LoadRhsTile(llvm::Value* offset, int64 count) { - llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); - std::vector result; - result.reserve(count); - for (int64 i = 0; i < count; i++) { - result.push_back(vsl_.LoadBroadcast(base_pointer, i)); - } - return result; - } - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, - const std::vector& rhs_tile, - int64 columns, bool is_first_column); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, - bool is_first_tiled_column); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trival dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implemetns a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, }; -void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( - llvm::Value* column, int64 column_count, bool is_first_column) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, - /*column_count=*/column_count); - - std::vector rhs_tile = - LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, - /*columns=*/column_count, is_first_column); - EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); -} - -void ColumnMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k() % tile_cols(); - int64 column_limit = k() - column_remainder; - - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); - - if (column_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, - column_limit == 0); - } -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, - int64 columns, bool is_first_column) { - int64 row_limit = m() - (m() % tile_rows()); - - ksl_.For( - "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m() - (m() % tile_rows()); - if (row_start == m()) { - return; - } - - llvm::Value* columns_llvm = b_->getInt64(columns); - - // for (col = current_tile_col; col < (columns + current_tile_col); col++) - // for (row = row_start, row < m_; row++) { - // result[row] += lhs[row, col] * rhs[col] - // // Also take into account that if col is 0 then result[row] is not - // // initialized. - // } - - ksl_.For( - "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), - /*step=*/1, /*peel_first_iteration=*/false, - [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { - llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), - /*step=*/1, [&](llvm::Value* scalar_row) { - llvm::Value* product = vsl_.Mul( - vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = b_->CreateAnd( - is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.If( - setting_result_first_time, - /*true_block_generator=*/ - [&]() { - if (addend_) { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), - product), - result_, scalar_row); - } else { - vsl_.StoreScalar(product, result_, scalar_row); - } - }, - /*false_block_generator=*/ - [&]() { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), - result_, scalar_row); - }); - }); - }); -} +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); -// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ -// |M00|M10|M20|M30| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| -// +---+---+---+---+ -// |M03|M13|M23|M33| -// +---+---+---+---+ -// -// (Legend: rows are horizontal and columns are vertical; and each row is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is loaded from the row major left matrix. -// b. The right vector is loaded from the RHS vector. -// -// We keep 4 vector accumulators accumulating the following four vector -// expressions as we iterate over the row dimension: -// -// +------+------+------+------+ -// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) -// +------+------+------+------+ -// -// In the end we do a horizontal reduction over these 4 vector accumulators to -// get 4 values in the result vector. -// -// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer -// epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter - : public GemvConfig::User { +// Helper class for emitting LLVM IR to perform the dot operation. +class DotOpEmitter { public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"row_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { - CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } + explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); + + // Emits the IR to perform the dot operation. + Status Emit(); private: - MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k(), - /*major_dim_offset=*/row_start, - /*tile_size_along_major_dim=*/row_count); - } - - void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, - int64 row_count) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, - /*row_count=*/row_count); - std::vector vector_accumulators; - std::vector scalar_accumulators; - for (int i = 0; i < row_count; i++) { - vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); - scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); - } - EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, - &vector_accumulators); - EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, - &scalar_accumulators); - - std::vector accumulator_values; - std::transform( - vector_accumulators.begin(), vector_accumulators.end(), - std::back_inserter(accumulator_values), - [](const VectorVariable& vector_var) { return vector_var.Get(); }); - - std::vector horizontal_sums; - if (row_count == vsl_.vector_size()) { - if (addend_) { - horizontal_sums = vsl_.ComputeHorizontalSums( - std::move(accumulator_values), vsl_.LoadVector(addend_, row)); - } else { - horizontal_sums = - vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - } else { - horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - - for (int i = 0; i < row_count; i++) { - llvm::Value* result_value = - vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); - if (addend_ && row_count != vsl_.vector_size()) { - result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); - } - vsl_.StoreScalar(result_value, result_, offset); - } -} + // Emits instructions to perform a scalar dot product (a multiply of the + // LHS and RHS) and store the results in the target. + Status EmitScalarDot(); -void RowMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m() % tile_rows(); - int64 row_limit = m() - row_remainder; + // Emits a call to the CPU runtime to perform the matrix multiply. + Status EmitCallToRuntime(); - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); - - if (row_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); - } -} + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators) { - int64 column_limit = k() - (k() % tile_cols()); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); -} + // The number of columns in the LHS, which is also must be equal to the + // number of rows in the RHS. + int64 k; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators) { - int64 column_start = k() - (k() % tile_cols()); - if (column_start == k()) { - return; - } + // The number of columns on the RHS. + int64 n; - for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = b_->CreateMul( - b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); - } -} + // True if the LHS matrix is column major. + bool lhs_column_major; -// This class implements a tiled matrix multiplication algorithm, intended for -// multiplying small matrices that don't need cache tiling. -// -// In the future this can be used as the innermost GEBP loop in a GEMM kernel as -// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of -// high-performance matrix multiplication." ACM Transactions on Mathematical -// Software (TOMS) 34.3 (2008): 12.". -// -// This only supports canonical dot operations (i.e. where the lhs contraction -// dimension is 1 and the rhs contraction dimension is 0) over row major -// matrices. -class TiledSmallGemmEmitter { - public: - // Describe the dimensions of the kernel. - class Dimensions { - public: - explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; - int64 m() const { return m_; } - int64 k() const { return k_; } - int64 n() const { return n_; } + // True if the RHS matrix is column major. + bool rhs_column_major; - string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; - private: - const int64 m_; - const int64 k_; - const int64 n_; + // True if the result matrix is column major. + bool target_column_major; }; - // Represents the configuration of the emitter. The LLVM IR emitted by the - // emitter, modulo the LLVM values holding the input and output buffers, must - // be a function of the instance of `Config` passed to it. - // - // `dims` holds the matrix multiplication dimensions. - // - // `max_vectorization_width` is the maximum vector width (i.e. the width of - // the largest vector register we will use). This can be larger than the - // largest vector register supported by the machine -- LLVM will legalize - // these large vector widths into legally sized vectors. - // - // `max_vector_count` is the maximum number of vectors of size - // `max_vectorization_width` that we will attempt to process at once. - // - // `min_vectorization_width` is the smallest vector width the emitter will use - // -- below that it will devolve to using a scalar loop. - // - // The innermost reduction loop executes the matrix multiply in tiles of size - // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, - // ] in the RHS. - class Config { - public: - explicit Config(PrimitiveType scalar_type, Dimensions dims, - int64 max_vectorization_width, int64 max_vector_count, - int64 min_vectorization_width, int64 tile_size_m, - int64 tile_size_k) - : scalar_type_(scalar_type), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - max_vector_count_(max_vector_count), - min_vectorization_width_(min_vectorization_width), - tile_size_m_(tile_size_m), - tile_size_k_(tile_size_k) {} - - string GetCacheKey() const { - return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", - dims().ToString(), "_", max_vectorization_width(), - "_", min_vectorization_width(), "_", tile_size_m(), - "_", tile_size_k()); - } + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; - PrimitiveType scalar_type() const { return scalar_type_; } - Dimensions dims() const { return dims_; } - int64 max_vectorization_width() const { return max_vectorization_width_; } - int64 max_vector_count() const { return max_vector_count_; } - int64 min_vectorization_width() const { return min_vectorization_width_; } - - int64 tile_size_m() const { return tile_size_m_; } - int64 tile_size_k() const { return tile_size_k_; } - - private: - PrimitiveType scalar_type_; - Dimensions dims_; - int64 max_vectorization_width_; - int64 max_vector_count_; - int64 min_vectorization_width_; - int64 tile_size_m_; - int64 tile_size_k_; - }; + // Lowers the dot operation as a tiled Matrix*Vector loop. + void EmitTiledLlvmIrGemv(); - // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. - explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) - : lhs_(lhs), - rhs_(rhs), - result_(result), - config_(config), - b_(b), - ksl_(b_) { - CHECK(max_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width()))); - CHECK_GT(max_vector_count(), 0); - CHECK(min_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width()))); - CHECK_GE(max_vectorization_width(), min_vectorization_width()); - CHECK_GT(tile_size_k(), 0); - } + // Lowers the dot operation as a tiled Matrix*Matrix loop. + void EmitTiledLlvmIrGemm(); - void Emit(); + // Lowers the dot operation as a naive nested loop that computes the result + // one element at a time. + void EmitNaiveLlvmIrGemm(); - private: - // The HandleResiduesOnX helpers split the iteration space for dimension X - // into a multiple of the tile size on dimension X and an epilogue. These - // helpers ultimately call into `EmitTiledGemm` for emitting the - // tiled GEMM kernel. - - void HandleResiduesOnN(); - void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end); - - // This emits a tiled GEMM kernel. For a detailed description see the comment - // on the implementation. - void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); - - llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } - - Config config() const { return config_; } - Dimensions dims() const { return config().dims(); } - - int64 max_vectorization_width() const { - return config().max_vectorization_width(); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); } - int64 max_vector_count() const { return config().max_vector_count(); } - int64 min_vectorization_width() const { - return config().min_vectorization_width(); - } - int64 tile_size_m() const { return config().tile_size_m(); } - int64 tile_size_k() const { return config().tile_size_k(); } - PrimitiveType scalar_type() const { return config().scalar_type(); } - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - Config config_; + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + DotInfo dot_info_; + string dot_hlo_name_; + const llvm_ir::IrArray& target_array_; + const llvm_ir::IrArray& lhs_array_; + const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; + llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; + const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; - -void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } - -void TiledSmallGemmEmitter::HandleResiduesOnN() { - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. - - int64 current_vectorization_width = - max_vector_count() * max_vectorization_width(); - int64 current_vector_count = max_vector_count(); - - int64 n_start = 0; - while (n_start != dims().n() && - current_vectorization_width >= min_vectorization_width()) { - int64 n_end = dims().n() - (dims().n() % current_vectorization_width); - if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gemm"); - HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); - n_start = n_end; - } - if (current_vector_count == 1) { - current_vectorization_width /= 2; - } else { - current_vector_count--; - current_vectorization_width = - current_vector_count * max_vectorization_width(); - } - } - - if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); - HandleResiduesOnK(&vsl, n_i, n_i_next); - }); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { - int64 k_start = 0; - int64 k_end = dims().k() - (dims().k() % tile_size_k()); - if (k_end != k_start) { - HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end); - k_start = k_end; - } - - if (k_start != dims().k()) { - HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), - GetInt64(dims().k()), n_start, n_end); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnM( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { - const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), - GetInt64(0), GetInt64(m_end)); - - if (m_end != dims().m()) { - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); - } -} - -// The loop structure is: -// -// Iterate over dimension M as m: -// Iterate over dimension N as n: -// Iterate over dimension K as k: -// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) -// -// I.e. a just a tiled version of a "naive" GEMM. -// -// The tiling scheme is as follows: -// -// Let the LHS be: -// -// +----+----+----+ -// | a0 | b0 | c0 | . -// +----+----+----+ . -// | a1 | b1 | c1 | . -// +----+----+----+ -// .. .. -// -// and the RHS be: -// -// +----+----+----+----+ -// | p0 | p1 | p2 | p3 | . -// +----+----+----+----+ . -// | q0 | q1 | q2 | q3 | . -// +----+----+----+----+ -// | r0 | r1 | r2 | r3 | . -// +----+----+----+----+ . -// ...... ...... -// -// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted -// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] -// matrix that we can increment the result matrix by. -// -// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank -// 3 array, L, of dimension [2,3,4]: -// -// L[0,_,_] * L[1,_,_] -// * -// +----+----+----+----+ * +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | -// +----+----+----+----+ * +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | -// +----+----+----+----+ * +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | -// +----+----+----+----+ * +----+----+----+----+ -// -// -// Then we FMA L[0,_,_] with the RHS to get the first row of the result and -// L[1,_,_] with the RHS to get the second row of the result. For example, -// L[0,_,_] is computed as: -// -// +----+----+----+----+ +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | -// +----+----+----+----+ +----+----+----+----+ -// -// to get: -// -// +-------------------+-------------------+-------------------+--------- -// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... -// +-------------------+-------------------+-------------------+--------- -void TiledSmallGemmEmitter::EmitTiledGemm( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.For( - "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile( - vsl, b_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - ksl_.For( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - TileVariable result_tile_var(vsl, - result_memory_tile.LoadTile(n_i)); - ksl_.For( - "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - std::vector rhs_tile = - rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_tile_var.Get(); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_tile_var.Set(result_tile); - }); - - result_memory_tile.StoreTile(result_tile_var.Get(), n_i); - }); - }); -} - } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, +DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, @@ -974,7 +211,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) - : dot_(dot), + : dot_info_(std::move(dot_info)), + dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -984,58 +222,9 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) { - PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, b, - hlo_module_config, target_machine_features); - return dot_emitter.Emit(); -} - -bool DotOpEmitter::EmitSmallGemmIfProfitable( - const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (ShouldUseMultiThreadedEigen()) { - return false; - } - - if (!EnableExperimentalLlvmIrGemm()) { - // TODO(sanjoy): We should make these numbers micro-arch specific. - bool small_gemm = mat_mult_dims.k <= 128 && - ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || - (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); - if (!small_gemm) { - return false; - } - } - - if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); - - switch (primitive_type) { - default: - return false; - - case F32: - case F64: - case S32: - case S64: - break; - } - - if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && - mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { - return false; - } +void DotOpEmitter::EmitTiledLlvmIrGemm() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); + MatMultDims mat_mult_dims = GetMatMultDims(); llvm::Value* lhs = lhs_array_.GetBasePointer(); llvm::Value* rhs = rhs_array_.GetBasePointer(); @@ -1050,9 +239,8 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - b_->CreateMemSet( - target, b_->getInt8(0), size_bytes, - target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/1); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( @@ -1062,47 +250,28 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - TiledSmallGemmEmitter::Config config( - /*scalar_type=*/primitive_type, - TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, - /*max_vectorization_width=*/max_target_vector_width, - /*max_vector_count=*/tile_size_n_in_vector_width, - /*min_vectorization_width=*/std::min(4, max_target_vector_width), - /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - - VLOG(2) << "Emitting GEMM kernel in LLVM IR with config " - << config.GetCacheKey(); - const bool enable_fast_math = hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); - KernelSupportLibrary::EmitAndCallOutlinedKernel( + EmitSmallGemm( + /*scalar_type=*/primitive_type, + /*m=*/m, /*k=*/k, /*n=*/n, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, + /*rhs=*/rhs, /*result=*/target, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, - rhs, target, - [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - small_gemm_emitter.Emit(); - }); - - return true; + /*optimize_for_size=*/optimize_for_size); } -bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); +void DotOpEmitter::EmitTiledLlvmIrGemv() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); - if (!primitive_util::IsFloatingPointType(primitive_type) && - !primitive_util::IsIntegralType(primitive_type)) { - return false; - } + CHECK(primitive_util::IsFloatingPointType(primitive_type) || + primitive_util::IsIntegralType(primitive_type)); MatMultDims mat_mult_dims = GetMatMultDims(); bool is_column_major_matrix_vector = false; @@ -1143,9 +312,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } } - if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitSmallGemmIfProfitable(mat_mult_dims); - } + CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector); int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); @@ -1177,44 +344,27 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter::Config config( + EmitColumnMajorGemv( /*scalar_type=*/primitive_type, /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter::Config config( + EmitRowMajorGemv( /*scalar_type=*/primitive_type, - /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*tile_rows=*/tiling_factor, + /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, - addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } - - return true; } Status DotOpEmitter::Emit() { @@ -1240,11 +390,6 @@ Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. - // This routine assumes that the dot operation is not in a parallelized - // enclosing computation. - CHECK( - dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); - const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -1255,27 +400,41 @@ Status DotOpEmitter::Emit() { return EmitScalarDot(); } - if (EmitLlvmIrDotIfProfitable()) { - return Status::OK(); + switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_, + target_machine_features_)) { + case DotImplementationStrategy::kNaiveLlvmIr: + EmitNaiveLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemv: + EmitTiledLlvmIrGemv(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemm: + EmitTiledLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kEigen: + return EmitCallToRuntime(); } +} +void DotOpEmitter::EmitNaiveLlvmIrGemm() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { - return EmitCallToRuntime(); - } + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = - dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); - int64 rhs_reduction_dimension = - dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. - TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -1285,7 +444,7 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( @@ -1390,8 +549,6 @@ Status DotOpEmitter::Emit() { // Set the IR builder insert point to the exit basic block of the outer most // loop. b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); } Status DotOpEmitter::EmitScalarDot() { @@ -1438,7 +595,7 @@ Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded = ShouldUseMultiThreadedEigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1531,11 +688,11 @@ Status DotOpEmitter::EmitCallToRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_EQ(dot_.shape().dimensions_size(), 2); + CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; return { /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), @@ -1549,74 +706,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is rank 2. -static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -static bool AreValidGemmShapes( - const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, - const TargetMachineFeatures& target_machine_features) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - PrimitiveType output_primitive_type = output_shape.element_type(); - if (!(output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16)) { - return false; - } - - if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { - return false; - } - - auto is_aligned = [&](const Shape& shape) { - return GetMinimumAlignmentForArray(shape, target_machine_features) >= - TargetMachineFeatures::kEigenExpectedTensorAlignment; - }; - - if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || - !is_aligned(output_shape)) { - return false; - } - - return true; -} - -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, - const TargetMachineFeatures& target_machine_features) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::IsZeroElementArray(lhs_shape) || - ShapeUtil::IsZeroElementArray(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), - target_machine_features)) { - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } - } - - return false; -} - // For vector-matrix dot products, it is always profitable to make the Rhs // column major. absl::optional ProfitableToMakeDotOperandColumnMajor( @@ -1655,16 +744,319 @@ absl::optional ProfitableToMakeDotOperandColumnMajor( return {}; } -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { +namespace { +// Return whether the given shape is rank 2. +bool IsRank2(const Shape& shape) { return shape.rank() == 2; } + +bool IsSimpleLayout(const Layout& layout) { + return layout.tiles().empty() && layout.format() == DENSE; +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { + CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout())) + << lhs_shape.DebugString(); + CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout())) + << rhs_shape.DebugString(); + CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout())) + << output_shape.DebugString(); + + switch (output_shape.element_type()) { + case F64: + case F32: + case F16: + return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape); + default: + return false; + } +} + +bool IsAlignedGemm(const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) || + ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) { + return false; + } + + return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape, + dot_info.result_shape, target_machine_features); +} + +bool CanEmitTiledLlvmIrGemm( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + CHECK(IsAlignedGemm(dot_info, target_machine_features)); + + if (ShouldUseMultiThreadedEigen(config)) { + return false; + } + + int m = dot_info.result_shape.dimensions(0); + int k = dot_info.lhs_shape.dimensions( + dot_info.dim_nums.lhs_contracting_dimensions(0)); + int n = dot_info.result_shape.dimensions(1); + + if (!options::ForceEnableExperimentalLlvmIrGemm(config)) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = + k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32)); + if (!small_gemm) { + return false; + } + } + + bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0; + bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1; + + if (lhs_non_canonical || rhs_non_canonical) { + return false; + } + + if (dot_info.result_shape.element_type() == F16) { + // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL + // adding this comment NFC. + return false; + } + + return true; +} + +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType element_type = dot_info.result_shape.element_type(); // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. - const Shape& shape = dot.shape(); - return shape.dimensions_size() == 2 && - (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && - (primitive_util::IsFloatingPointType(shape.element_type()) || - primitive_util::IsIntegralType(shape.element_type())); + if (dot_info.result_shape.dimensions_size() == 2 && + (dot_info.result_shape.dimensions(0) == 1 || + dot_info.result_shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(element_type) || + primitive_util::IsIntegralType(element_type))) { + return DotImplementationStrategy::kTiledLlvmIrGemv; + } + + if (IsAlignedGemm(dot_info, target_machine_features)) { + return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) + ? DotImplementationStrategy::kTiledLlvmIrGemm + : DotImplementationStrategy::kEigen; + } + + return DotImplementationStrategy::kNaiveLlvmIr; } +Status EmitNonBatchDotOperation( + DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType type = target_array.GetShape().element_type(); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type || + C128 == type); + DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), + target_array, lhs_array, rhs_array, addend_array, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + return dot_emitter.Emit(); +} + +Shape DropFirstDim(const Shape& shape) { + absl::Span array_shape_dims(shape.dimensions()); + array_shape_dims.remove_prefix(1); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + array_shape_dims); +} + +Shape CollapseFirstNDims(const Shape& shape, int64 n) { + absl::Span input_shape_dims(shape.dimensions()); + int64 prefix_dim = + std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, + 1ll, std::multiplies()); + DimensionVector result_dims; + result_dims.push_back(prefix_dim); + std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), + std::back_inserter(result_dims)); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + result_dims); +} + +llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b, + const llvm_ir::IrArray& array, int64 n) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + const Shape& shape = array.GetShape(); + CHECK(shape.has_layout() && + LayoutUtil::IsMonotonicWithDim0Major(shape.layout())); + CHECK_GE(shape.dimensions_size(), n); + Shape new_shape = CollapseFirstNDims(shape, n); + llvm::Value* new_value = b->CreateBitCast( + array.GetBasePointer(), + llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo()); + return llvm_ir::IrArray(new_value, std::move(new_shape)); +} + +Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { + // Checks some invariants that do not hold in general, but DotDecomposer + // should have established for us. This is just a debugging aid. + TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); + std::vector batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size()); + absl::c_iota(batch_dim_numbers, 0); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); + return Status::OK(); +} + +// Slice out the inner array at batch index `batch_index` from `outer_array`. +llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, + llvm::Value* batch_index, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + + Shape inner_shape = DropFirstDim(outer_array.GetShape()); + llvm_ir::IrArray::Index slice_index(b->getInt64Ty()); + slice_index.push_back(batch_index); + slice_index.InsertAt( + /*index=*/1, outer_array.GetShape().dimensions_size() - 1, + b->getInt64(0)); + llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); + llvm::Type* slice_ptr_type = + llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo(); + return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type), + std::move(inner_shape)); +} + +Status EmitBatchDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); + + // Lower a batch dot into a sequence of non-batch dot operations. + + int64 num_batch_dims = + dot.dot_dimension_numbers().lhs_batch_dimensions_size(); + + // First reshape the inputs to make sure we only have one batch dimension. + // This is a no-op bitcast because the operands have to be in row-major layout + // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading + // dimensions (established by DotDecomposer and checked by + // ValidateDotDimensionNumbers above). + llvm_ir::IrArray lhs_array_reshaped = + CollapseFirstNDims(b, lhs_array, num_batch_dims); + llvm_ir::IrArray rhs_array_reshaped = + CollapseFirstNDims(b, rhs_array, num_batch_dims); + llvm_ir::IrArray target_array_reshaped = + CollapseFirstNDims(b, target_array, num_batch_dims); + + int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0); + + KernelSupportLibrary ksl(b); + + return ksl.ForWithStatus( + "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, + [&](llvm::Value* indvar) { + DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); + adjusted_dim_numbers.clear_lhs_batch_dimensions(); + adjusted_dim_numbers.clear_rhs_batch_dimensions(); + + // Create a DotInfo representing the "inner" non-batch dot operation. + DotInfo dot_info; + dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape()); + dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape()); + dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape()); + dot_info.dim_nums = dot.dot_dimension_numbers(); + dot_info.dim_nums.clear_lhs_batch_dimensions(); + dot_info.dim_nums.clear_rhs_batch_dimensions(); + + dot_info.dim_nums.set_lhs_contracting_dimensions( + 0, + dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims); + dot_info.dim_nums.set_rhs_contracting_dimensions( + 0, + dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims); + + llvm_ir::IrArray lhs_slice = + SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray rhs_slice = + SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray target_slice = SliceOutInnerArray( + target_array_reshaped, /*batch_index=*/indvar, b); + + // Emit the inner non-batch dot operation. + return EmitNonBatchDotOperation( + dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + }); +} + +bool IsBatchDot(const HloInstruction& instr) { + if (auto* dot_instr = DynCast(&instr)) { + return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0; + } + + return false; +} +} // namespace + +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + // TODO(sanjoy): This is not quite right, it should be `impl_strategy == + // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy == + // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping + // the CL adding this comment NFC. + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty()); + + if (IsBatchDot(dot)) { + TF_RET_CHECK(addend_array == nullptr); + return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); + } + + return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, + lhs_array, rhs_array, addend_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae..105bd3005c86d87443b2528eba7b0106ad70590e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -30,9 +30,16 @@ limitations under the License. namespace xla { namespace cpu { +// Returns true if the two operands and the output of `dot_instr` must have row +// major layout. +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features); -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, +// Returns true our lowering strategy for `dot_instr` can fold in transposes to +// the either of the inputs. +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column @@ -41,129 +48,24 @@ bool PotentiallyImplementedAsEigenDot( absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); -// Returns true to indicate that we can generate a tiled LLVM IR implementation -// for |dot|. -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); - -// Helper class for emitting LLVM IR to perform the dot operation. -class DotOpEmitter { - public: - // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and - // place the result in target_array. IR is emitted at current insert point of - // the builder. Upon completion of the method, the insert point is set to the - // end of all instructions emitted for this operation. - // - // If `addend_array` is not nullptr then it must be an array of the same - // dimensions as the result, and the result is computed as `addend_array` + - // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported - // for Matrix-vector products. - static Status EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - private: - DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - // Emits the IR to perform the dot operation. - Status Emit(); - - // Emits instructions to perform a scalar dot product (a multiply of the - // LHS and RHS) and store the results in the target. - Status EmitScalarDot(); - - // Emit an LLVM IR implementation of the dot operation if we can. Returns - // true if an LLVM IR implementation was emitted. - bool EmitLlvmIrDotIfProfitable(); - - // Emits a call to the CPU runtime to perform the matrix multiply. - Status EmitCallToRuntime(); - - // Represents the dimensions of a matrix-matrix multiply operation. - struct MatMultDims { - // The number of rows in the LHS. - int64 m; - - // The number of columns in the LHS, which is also must be equal to the - // number of rows in the RHS. - int64 k; - - // The number of columns on the RHS. - int64 n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is not 1. - bool lhs_non_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is not 0. - bool rhs_non_canonical; - - // True if the result matrix is column major. - bool target_column_major; - }; - - // Get the MatMultDims instance for the dot product this DotOpEmitter - // represents. Precondition: the dot is of rank 2 (and thus its operands are - // of rank 2 as well). - MatMultDims GetMatMultDims() const; - - bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); - - // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector - // registers. - int64 GetGemvTilingFactor() const { - const int64 kDefaultTilingFactor = 8; - return options::LlvmIrGemvTilingFactor(hlo_module_config_) - .value_or(kDefaultTilingFactor); - } - - std::tuple GetGemmTileSize() const { - // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz - // - // TODO(b/80093688): Tune for other architectures and centralize this - // information in one place. - const std::tuple kDefaultTileSize = - std::tuple(11, 9, 1); - return options::LlvmIrGemmTileSize(hlo_module_config_) - .value_or(kDefaultTileSize); - } - - // Returns true if we should use an experimental implementation of GEMM - // (general matrix matrix multiplication) if possible. - bool EnableExperimentalLlvmIrGemm() const { - return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); - } - - // Returns true if we should call into multi-threaded Eigen routines. - bool ShouldUseMultiThreadedEigen() { - return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); - } - - const HloInstruction& dot_; - const llvm_ir::IrArray& target_array_; - const llvm_ir::IrArray& lhs_array_; - const llvm_ir::IrArray& rhs_array_; - const llvm_ir::IrArray* addend_array_; - llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* b_; - const HloModuleConfig& hlo_module_config_; - const TargetMachineFeatures& target_machine_features_; -}; - +// Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and +// place the result in target_array. IR is emitted at current insert point of +// the builder. Upon completion of the method, the insert point is set to the +// end of all instructions emitted for this operation. +// +// If `addend_array` is not nullptr then it must be an array of the same +// dimensions as the result, and the result is computed as `addend_array` + +// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported +// for Matrix-vector products. +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cc28918ed60a8086135846e2b9b1b9d75ec31ef6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +// ----------------------------------------------------------------------------- +// INTERNAL HEADER. +// +// This file exposes internal implementation details from dot_op_emitter.cc for +// unit tests. Please do not depend on this! +// +// ----------------------------------------------------------------------------- + +namespace xla { +namespace cpu { +namespace internal { + +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); + } +}; + +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trival dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implemetns a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, +}; + +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); +} // namespace internal +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 1a8bedfe6afb4f096ddd4703c312b84d521a7ba5..a8b139aec9e96b6bb580baf74789df7c998cebf8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -26,7 +26,7 @@ namespace cpu { int64 GetMinimumAlignmentForArray( const Shape& shape, const TargetMachineFeatures& target_machine_features) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); // We don't require a layout to be set on `shape`. This only works on CPU diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index ed7fe59c80ed68420cea8b51e1732489ac2a874e..0226e8275cb0e1de39c4c2e9a06d4cfa1a4854d3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,11 +24,9 @@ limitations under the License. #include #include +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/platform/logging.h" -// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -70,6 +68,8 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" namespace xla { @@ -223,11 +223,11 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { } Status IrEmitter::HandleCopy(HloInstruction* copy) { - if (ShapeUtil::IsTuple(copy->shape())) { + if (copy->shape().IsTuple()) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else if (ShapeUtil::IsArray(copy->shape())) { + } else if (copy->shape().IsArray()) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } @@ -239,10 +239,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); DCHECK_GE(byte_size, 0); - // Largest scalar is a complex64 so we don't need to worry about the + // Largest scalar is a complex128 so we don't need to worry about the // int64->int truncation here. - DCHECK_LE(byte_size, 8); - return byte_size; + DCHECK_LE(byte_size, 16); + + // Allocations may be 8-byte aligned if part of a small block. + return std::min(8LL, byte_size); } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -316,7 +318,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), @@ -346,7 +348,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); - if (ShapeUtil::IsTuple(data_shape)) { + if (data_shape.IsTuple()) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to @@ -470,7 +472,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { const Shape& operand_shape = operand->shape(); llvm::Value* value = GetEmittedValueFor(operand); - if (!ShapeUtil::IsTuple(operand_shape)) { + if (!operand_shape.IsTuple()) { return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); } @@ -493,6 +495,26 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { const HloSortInstruction* sort = Cast(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); + PrimitiveType keys_type = keys_shape.element_type(); + switch (keys_type) { + case PRED: + case S8: + case U8: + case S16: + case U16: + case F16: + case S32: + case U32: + case F32: + case S64: + case U64: + case F64: + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = @@ -535,110 +557,106 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { higher_dimensions *= normalized_keys_shape.dimensions(i); } int64 lower_dimensions = 1; - for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + for (int64 i = normalized_keys_shape.rank() - 1; i > physical_dimension_to_sort; --i) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - PrimitiveType keys_type = keys_shape.element_type(); - const char* fn_name = nullptr; - llvm::Type* keys_native_type = nullptr; - switch (keys_type) { - case PRED: - fn_name = runtime::kKeyValueSortPREDSymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S8: - fn_name = runtime::kKeyValueSortS8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case U8: - fn_name = runtime::kKeyValueSortU8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S16: - fn_name = runtime::kKeyValueSortS16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case U16: - fn_name = runtime::kKeyValueSortU16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case F16: - fn_name = runtime::kKeyValueSortF16SymbolName; - keys_native_type = b_.getHalfTy()->getPointerTo(); - break; - case S32: - fn_name = runtime::kKeyValueSortS32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case U32: - fn_name = runtime::kKeyValueSortU32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case F32: - fn_name = runtime::kKeyValueSortF32SymbolName; - keys_native_type = b_.getFloatTy()->getPointerTo(); - break; - case S64: - fn_name = runtime::kKeyValueSortS64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case U64: - fn_name = runtime::kKeyValueSortU64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case F64: - fn_name = runtime::kKeyValueSortF64SymbolName; - keys_native_type = b_.getDoubleTy()->getPointerTo(); - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); + llvm::FunctionType* less_than_type = llvm::FunctionType::get( + b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()}, + /*isVarArg=*/false); + auto less_than_function = llvm_ir::CreateFunction( + less_than_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/false, + /*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"), + module_); + // Emit the code for the less_than function. + { + llvm::IRBuilder<>::InsertPointGuard guard(b_); + + auto* entry_bb = + llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function); + + b_.SetInsertPoint(entry_bb); + auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_); + CHECK_EQ(less_than_function->arg_size(), 2); + llvm::Value* keys_lhs_ptr = less_than_function->arg_begin(); + keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo()); + llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1; + keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo()); + + // TODO(b/122298745): Replace the custom compare logic with a call to the + // computation specified for the Sort op. + llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr); + llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr); + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(keys_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. + auto k = b_.getInt(llvm::APInt::getSignedMaxValue( + keys_lhs->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b_.CreateSub(k, v), v); + }; + keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type); + keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type); + keys_lhs = maybe_flip(keys_lhs); + keys_rhs = maybe_flip(keys_rhs); + } else if (!primitive_util::IsSignedIntegralType(keys_type)) { + is_signed_comparison = false; + } + llvm::Value* result = + b_.CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + keys_lhs, keys_rhs); + llvm::ReturnInst::Create(b_.getContext(), + /*retVal=*/result, entry_bb); } llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), - {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo()}, + b_.getInt32Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); - auto* key_value_sort_func = llvm::cast( - module_->getOrInsertFunction(fn_name, key_value_sort_type)); + auto* key_value_sort_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kKeyValueSortSymbolName, key_value_sort_type)); key_value_sort_func->setCallingConv(llvm::CallingConv::C); key_value_sort_func->setDoesNotThrow(); - llvm::Value* values; - llvm::Value* sizes; - if (sort->values_count() == 0) { - values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()); - sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo()); - } else { - values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt8PtrTy(), b_.getInt32(sort->values_count()), - "cc_values_alloca", &b_); - sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca", - &b_); - for (int64 i = 0; i < sort->values_count(); ++i) { - llvm::Value* value_as_i8ptr = - PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy()); - llvm::Value* slot_in_values_alloca = - ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); - Store(value_as_i8ptr, slot_in_values_alloca); - llvm::Value* slot_in_sizes_alloca = - ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); - llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i + 1)->shape().element_type())); - Store(size, slot_in_sizes_alloca); - } + llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca", + &b_); + llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca", + &b_); + for (int64 i = 0; i < sort->operand_count(); ++i) { + llvm::Value* value_as_i8ptr = + PointerCast(destination_addresses[i], b_.getInt8PtrTy()); + llvm::Value* slot_in_values_alloca = + ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); + Store(value_as_i8ptr, slot_in_values_alloca); + llvm::Value* slot_in_sizes_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); + llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type())); + Store(size, slot_in_sizes_alloca); } Call(key_value_sort_func, - {PointerCast(destination_addresses[0], keys_native_type), - b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->values_count()), sizes}); + b_.getInt32(sort->operand_count()), sizes, less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, @@ -779,8 +797,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { const auto init_value = select_and_scatter->operand(2); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); // TODO(b/31410564): Implement dilation for select-and-scatter. @@ -942,12 +960,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64, C128})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (dnums.lhs_contracting_dimensions_size() != 1) { // This is disallowed by ShapeInference today. @@ -970,10 +984,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - return DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_); + return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_); } StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( @@ -1118,7 +1132,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64})); + /*supported_types=*/{F16, F32, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. @@ -1362,7 +1376,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { assignment_.GetUniqueSlice(crs, {i})); const Shape& operand_shape = crs->operand(i)->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)) + CHECK(operand_shape.IsArray()) << "Operands to all-reduce must be arrays: " << crs->ToString(); operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); @@ -1399,7 +1413,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { int64 delta = 0; for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { - if (reduced_dims.count(i)) { + if (reduced_dims.contains(i)) { delta++; } else { InsertOrDie(&unreduced_dim_map, i, i - delta); @@ -1412,7 +1426,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { for (int64 operand_dim_idx = 0; operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); - if (!reduced_dims.count(operand_dim)) { + if (!reduced_dims.contains(operand_dim)) { if (FindOrDie(unreduced_dim_map, operand_dim) != result_shape.layout().minor_to_major(result_dim_idx++)) { return false; @@ -1709,10 +1723,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( vectorization_factor_in_bytes / ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); - bool is_reduction_over_minor_dimension = - std::find(dimensions.begin(), dimensions.end(), - LayoutUtil::Minor(arg->shape().layout(), 0)) != - dimensions.end(); + bool is_reduction_over_minor_dimension = absl::c_linear_search( + dimensions, LayoutUtil::Minor(arg->shape().layout(), 0)); unsigned element_alignment = tensorflow::MathUtil::GCD( ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), @@ -1724,7 +1736,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( return false; } - CHECK(!ShapeUtil::IsTuple(reduce->shape())); + CHECK(!reduce->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we @@ -1891,7 +1903,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( Status IrEmitter::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on CPU"); } auto arg = reduce->mutable_operand(0); @@ -1990,7 +2002,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // The memcpy will copy elements that are logically this shape (allowed to be // scalar). const Shape logical_element_shape = ShapeUtil::FilterDimensions( - [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, + [&inner_dims](int64 dim) { return inner_dims.contains(dim); }, operand->shape()); const int64 primitive_elements_per_logical_element = @@ -2205,10 +2217,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_)); + TF_RETURN_IF_ERROR( + EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + &addend_array, GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2267,14 +2279,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); // Write the tuple table if the output is a tuple. - if (ShapeUtil::IsTuple(custom_call->shape())) { + if (custom_call->shape().IsTuple()) { std::vector base_ptrs; for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); ++i) { const Shape& elem_shape = ShapeUtil::GetTupleElementShape(custom_call->shape(), i); - TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape)) - << "Nested tuples not implemented"; + TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(custom_call, {i})); llvm::Value* addr = EmitBufferPointer(slice, elem_shape); @@ -2402,8 +2413,7 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); - auto concat_dim_layout_itr = - std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); + auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim); std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), @@ -2803,7 +2813,7 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); llvm::LoadInst* param_address_untyped = Load(param_address_offset); - if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, target_shape); @@ -2957,8 +2967,7 @@ Status IrEmitter::ElementTypesSameAndSupported( TF_RET_CHECK(!operands.empty()); PrimitiveType primitive_type = operands[0]->shape().element_type(); - if (std::find(supported_types.begin(), supported_types.end(), - primitive_type) == supported_types.end()) { + if (!absl::c_linear_search(supported_types, primitive_type)) { return Unimplemented("unsupported operand type %s in op %s", PrimitiveType_Name(primitive_type), HloOpcodeString(instruction.opcode())); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index db76de4bb2b8ed568bf2557a30fa216d0cbe518d..974dd7cd3f2254bfbc86fffae02c06c481af8902 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -250,14 +250,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); - // Emits a function into the current module. This can be used for - // computations embedded inside other computations, such as the - // function that a map operation applies. - StatusOr EmitFunction( - HloComputation* function, // The function to emit. - absl::string_view - function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. to the computation nested // within a reduce or a map). Thread local callees (by definition) only write // to and read from thread local allocations. @@ -448,7 +440,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, computation_to_profile_idx_; // Maps HLOs to Values emitted for them. - std::unordered_map emitted_value_; + absl::flat_hash_map emitted_value_; llvm_ir::AliasAnalysis alias_analysis_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index f8441c3e345504616485c6b34b4302acd5cc23a3..a6f4273a5a70aab0bc88383283d2a55b1ecb1681 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -34,7 +34,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, b_); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index ede7f433ca6b2cc5629115f800348be9dfb2b93b..6121d1ca9a5c785cedd947200d3e7e320aa06bc2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -146,11 +146,9 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || - PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { + instruction->shape().IsTuple()) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 722aa3120ef4d8c957873ac58c361f19632dde1f..a0667d0d9d1cde246f4b74626859955beeec08b0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include -#include #include -#include #include +#include #include -#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/dynamic_annotations.h" @@ -28,80 +26,14 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace { -using tensorflow::int16; using tensorflow::int32; using tensorflow::int64; -using tensorflow::int8; -using tensorflow::uint16; -using tensorflow::uint32; -using tensorflow::uint64; -using tensorflow::uint8; - -template -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements); -} - -// We would like a total order of floating point numbers so that the -// sort has a predictable behavior in the presence of NaNs. Rather -// than using floating point comparison, we use the following trick: -// If f is a float, and -// x = bit_cast(f); -// y = x < 0 ? 0x7FFFFFFF - x : x; -// then y is ordered as an int32 such that finite values have the -// obvious order, -0 is ordered before 0, and -NaN and NaN appear at -// the beginning and end of the ordering. -template -CastType Convert(KeyType value) { - CastType casted_value; - memcpy(&casted_value, &value, sizeof(CastType)); - if (casted_value < 0) { - return static_cast(std::numeric_limits::max()) - - casted_value; - } - return casted_value; -} - -template -bool LessThan(KeyType lhs, KeyType rhs) { - return Convert(lhs) < - Convert(rhs); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, - int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), - Eigen::half_impl::half_to_float(rhs.first)); - }); -} +} // namespace -template -void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, - int32* values_primitive_type_size_in_bytes) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( + int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes, + bool (*less_than)(char*, char*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); @@ -121,8 +53,8 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 num_iteration_elements = a * c; int64 sort_dimension_offset = c; - std::unique_ptr[]> row_to_sort( - new std::pair[sort_dimension_elements]); + std::unique_ptr indices(new int64[sort_dimension_elements]); + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { @@ -135,24 +67,22 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - // TODO(b/26783907): We could define a custom iterator class that references - // all arrays. Then we could avoid the intermediate copy. However this - // would become more complicated, and it is not clear if the benefit is high - // enough. - for (int64 i = 0; i < sort_dimension_elements; ++i) { - row_to_sort[i] = - std::make_pair(keys[base_offset + i * sort_dimension_offset], i); - } - KeyValueSort(row_to_sort.get(), sort_dimension_elements); - for (int64 i = 0; i < sort_dimension_elements; ++i) { - keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; - } - - // Reorder the values according to the order defined by the keys. + std::stable_sort( + indices.get(), indices.get() + sort_dimension_elements, + [&](int64 a, int64 b) { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + return less_than(values[0] + memory_index_lhs, + values[0] + memory_index_rhs); + }); + + // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { for (int64 i = 0; i < sort_dimension_elements; ++i) { int64 memory_index = - (base_offset + row_to_sort[i].second * sort_dimension_offset) * + (base_offset + indices[i] * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; reordered_values[i] = @@ -168,88 +98,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, } } } -} // namespace - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( - int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( - uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( - int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( - uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( - int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( - uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( - float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( - int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( - uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( - double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 7821099386969e855ea1737cf53ef49c15c6e93b..5460af3485b94aaef1a5822a79e4fa325bcb67ea 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -21,76 +21,19 @@ limitations under the License. extern "C" { -// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' -// dimension of 'keys' is sorted into ascending order. If 'values_count' is <= -// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr. -// If 'values_count' > 0, they contain exactly 'values_count' many elements. -// Each element of 'values' also represents a 3-dimensional shape with -// dimensions [a, b, c], and the size of the primitive type of the i-th shape -// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in -// each 'values' shape are reordered in such a way that if the element at index -// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values' -// shape is also moved to index 'j' (which means that the same elements -// correspond to each other as before). -extern void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, +// Each entry in 'values' represents a 3-dimensional shape with dimensions +// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending +// order according to the results of comparisons using the provided 'less_than' +// function. 'values_count' must be > 0 and specifies the number of entries in +// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive +// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' +// bytes. The elements in each 'values' shape are reordered in the same way +// according to the comparisons using the first shape. +extern void __xla_cpu_runtime_KeyValueSort( + tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS8( - tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU8( - tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS16( - tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU16( - tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS32( - tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU32( - tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF32( - float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS64( - tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU64( - tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF64( - double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); + tensorflow::int32* values_primitive_type_size_in_bytes, + bool (*less_than)(char*, char*)); } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 1ed743afc30af7c7ff38c7d2a738f2e376270952..1f7204e67a413efabd34cd7d88ced4c82ee7a5df 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -20,6 +20,10 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 296f39a4853f2d3f7030209a921001e92c39d609..9c2685674fbc133de1220caef81ac3b60a1c0f7c 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -116,13 +116,26 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, orc_jit_memory_mapper::GetInstance()); result.Resolver = symbol_resolver_; return result; + }, + /*NotifyLoaded=*/ + llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), + /*NotifyFinalized=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + this->NotifyObjectFinalized(object, object_info); + }, + /*NotifyFreed=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object) { + this->NotifyObjectFreed(object); }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, enable_fast_math, disable_expensive_passes, std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + std::move(post_optimization_hook))), + gdb_jit_event_listener_( + llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } @@ -147,6 +160,20 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { return symbol_info; } +void SimpleOrcJIT::NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); +} + +void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyFreeingObject(key); +} + SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto key = execution_session_.allocateVModule(); @@ -213,18 +240,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 78406ba143570183aea09d79db3f9b708c21bf70..3307c2f93d796bbdcd49af7f68e9f6c388e402ca 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" @@ -99,6 +100,11 @@ class SimpleOrcJIT { private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + void NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info); + void NotifyObjectFreed(const llvm::object::ObjectFile& object); + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; @@ -107,6 +113,15 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + + // Non owning pointer to a JIT event listener that registers the JIT events + // with an attached GDB. + // + // Note: we get a pointer to this event listener using + // `createGDBRegistrationListener` which makes it look like we're supposed to + // free this, but the function is poorly named and really just returns a + // pointer to a static object. + llvm::JITEventListener* gdb_jit_event_listener_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index f8f5f392da8ab3348e63185aecf7b639daacaa42..8b7f843582b697058fe328fe69990122d868ada4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -16,7 +16,6 @@ limitations under the License. // Tests that we call into Eigen for dot operations as needed. #include -#include #include #include "absl/strings/str_cat.h" @@ -102,10 +101,10 @@ std::vector GetDotTestCases() { return result; } -INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, - CpuEigenDotOperationTest, - ::testing::ValuesIn(GetDotTestCases()), - DotTestSpecToString); +INSTANTIATE_TEST_SUITE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index 5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e..f0f897e9635600b22e0c389ba056899e4d6ab3d4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -48,7 +48,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 9b10c49f4f547edfb2164f98c49cceb031148bdc..9078b8fd1ff6cb0ddac89d5fcd13a9ccfae07763 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -59,8 +59,9 @@ class CpuUnaryIntrinsicTest string features{spec.features.data(), spec.features.size()}; if (!features.empty()) { - std::replace_if(features.begin(), features.end(), - [](char c) { return c != '_' && !isalnum(c); }, '_'); + std::replace_if( + features.begin(), features.end(), + [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, '_'); } else { features = ""; } @@ -140,10 +141,10 @@ IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { HloOpcode::kLog, kTriple_android_arm, "", R"(CHECK: fadd fast <4 x float> )"}}; -INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, - CpuUnaryIntrinsicTest, - ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), - CpuUnaryIntrinsicTest::Name); +INSTANTIATE_TEST_SUITE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb6c44b70ab34d0a294880b5de4fe0b3ba5e19e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -0,0 +1,1014 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" + +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace cpu { +namespace { + +using tensorflow::int64; + +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { + public: + // Constructs a MemoryTile that can operate on tiles consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl), b_(b) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(absl::Span tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); + } + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* b_; + std::vector pointers_; +}; + +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submarix. +// +// TODO(sanjoy): We should investigate if using gather loads and scatter stores +// can be used here have the same inner loop for both column-major and row-major +// matrix-vector products. +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, + llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequence of `count` values, each one broadcasted to the vector width. + std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { + return; + } + + llvm::Value* columns_llvm = b_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. + // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = b_->CreateAnd( + is_first_scalar_col, b_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, + &vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); + llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = b_->CreateMul( + b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +// This class implements a tiled matrix multiplication algorithm, intended for +// multiplying small matrices that don't need cache tiling. +// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of +// high-performance matrix multiplication." ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class TiledSmallGemmEmitter { + public: + // Describe the dimensions of the kernel. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + b_(b), + ksl_(b_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; +}; + +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } + +void TiledSmallGemmEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gemm"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void TiledSmallGemmEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.For( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); + ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, + tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + +} // namespace + +void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + TiledSmallGemmEmitter::Config config( + /*scalar_type=*/scalar_type, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_vectorization_width, + /*max_vector_count=*/max_vector_count, + /*min_vectorization_width=*/min_vectorization_width, + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/result, b); + small_gemm_emitter.Emit(); + }); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..0a82326cc3704bce8c122261383249c60eda1f3a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// These routines emit LLVM IR implementing tiled GEMM and GEMV routines. + +void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size); + +void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, + tensorflow::int64 k, tensorflow::int64 n, + tensorflow::int64 max_vectorization_width, + tensorflow::int64 max_vector_count, + tensorflow::int64 min_vectorization_width, + tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062..70173d43d79e931b75f131ad380ad98359cc78b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -73,15 +73,14 @@ ENTRY TestComputation { abs = f32[] abs(arg) add = f32[] add(arg, gte) broadcast = f32[42] broadcast(add), dimensions={} - slice = f32[0] slice(broadcast), slice={[1:2]} + slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) eq = pred[] equal-to(arg, gte) neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; std::unique_ptr module = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) - .ConsumeValueOrDie(); + ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie(); ElementwiseTestVisitor visitor; TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index b2ba2617902104bfea06713332fa1c2aedea536d..e8bc6d05716a2ef02e0280e86c7df4ac22fe78c4 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -156,29 +158,192 @@ Status DecomposeBatchDot(HloInstruction* dot) { return computation->ReplaceInstruction(dot, new_dot); } +// Convert a dot into a canonical form where non-contracting and contracting +// dimensions are reshaped together and batch dimensions are the most major +// dimensions. The requires transposing and reshapes the lhs and rhs and +// reshaping the output batch to the original shape. +Status CanonicalizeDot(HloInstruction* original_dot) { + auto computation = original_dot->parent(); + const auto& original_dnums = original_dot->dot_dimension_numbers(); + const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size(); + const int64 num_contracting_dims = + original_dnums.lhs_contracting_dimensions_size(); + + const auto& lhs_shape = original_dot->operand(0)->shape(); + const int64 lhs_rank = lhs_shape.rank(); + const int64 num_lhs_non_contracting_dims = + lhs_rank - num_batch_dims - num_contracting_dims; + + std::vector lhs_non_contracting_dims; + lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); + int64 lhs_contracting_size = 1; + int64 lhs_non_contracting_size = 1; + std::vector batch_dim_sizes; + batch_dim_sizes.reserve(num_batch_dims); + for (int64 i = 0; i < lhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { + lhs_contracting_size *= lhs_shape.dimensions(i); + } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), + i)) { + batch_dim_sizes.push_back(lhs_shape.dimensions(i)); + } else { + lhs_non_contracting_dims.push_back(i); + lhs_non_contracting_size *= lhs_shape.dimensions(i); + } + } + // The canonical form of the lhs is + // [BatchDims, NonContractingDims, ContractingsDims] + std::vector lhs_transpose; + lhs_transpose.reserve(lhs_rank); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_batch_dimensions().begin(), + original_dnums.lhs_batch_dimensions().end()); + lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(), + lhs_non_contracting_dims.end()); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_contracting_dimensions().begin(), + original_dnums.lhs_contracting_dimensions().end()); + HloInstruction* transposed_lhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose), + lhs_shape), + original_dot->mutable_operand(0), lhs_transpose)); + std::vector lhs_reshape_dims = batch_dim_sizes; + lhs_reshape_dims.push_back(lhs_non_contracting_size); + lhs_reshape_dims.push_back(lhs_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_lhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims), + transposed_lhs)); + + const auto& rhs_shape = original_dot->operand(1)->shape(); + const int64 rhs_rank = rhs_shape.rank(); + const int64 num_rhs_non_contracting_dims = + rhs_rank - num_batch_dims - num_contracting_dims; + std::vector rhs_non_contracting_dims; + rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); + int64 rhs_non_contracting_size = 1; + int64 rhs_contracting_size = 1; + for (int64 i = 0; i < rhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { + rhs_contracting_size *= rhs_shape.dimensions(i); + } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), + i)) { + rhs_non_contracting_dims.push_back(i); + rhs_non_contracting_size *= rhs_shape.dimensions(i); + } + } + + // The canonical form of the rhs is + // [BatchDims, ContractingsDims, NonContractingDims] + std::vector rhs_transpose; + rhs_transpose.reserve(rhs_rank); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_batch_dimensions().begin(), + original_dnums.rhs_batch_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_contracting_dimensions().begin(), + original_dnums.rhs_contracting_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), + rhs_non_contracting_dims.end()); + HloInstruction* transposed_rhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose), + rhs_shape), + original_dot->mutable_operand(1), rhs_transpose)); + + std::vector rhs_reshape_dims = batch_dim_sizes; + rhs_reshape_dims.push_back(rhs_contracting_size); + rhs_reshape_dims.push_back(rhs_non_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_rhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), + transposed_rhs)); + + std::vector dot_dims = batch_dim_sizes; + dot_dims.push_back(lhs_non_contracting_size); + dot_dims.push_back(rhs_non_contracting_size); + + DotDimensionNumbers dot_dnums; + for (int64 i = 0; i < num_batch_dims; ++i) { + dot_dnums.add_lhs_batch_dimensions(i); + dot_dnums.add_rhs_batch_dimensions(i); + } + dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); + + HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims), + reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); + + return computation->ReplaceInstruction( + original_dot, computation->AddInstruction(HloInstruction::CreateReshape( + original_dot->shape(), dot))); +} + } // namespace StatusOr DotDecomposer::Run(HloModule* module) { XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); - // Gather all batch Dot operations. - std::vector batch_dots; + // Gather all Non-canonical Dot operations. + std::vector non_canonical_dots; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kDot) { continue; } const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { - batch_dots.push_back(instruction); + // A dot it not canonical if there are more than one contracting + // dimension. + if (dnums.lhs_contracting_dimensions_size() > 1) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty() && + dnums.lhs_contracting_dimensions().empty()) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty()) { + continue; + } + std::vector canonical_batch_dims( + dnums.lhs_batch_dimensions_size()); + absl::c_iota(canonical_batch_dims, 0); + if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) || + !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) { + non_canonical_dots.push_back(instruction); } } } - // Decompose each batch Dot in 'batch_dots'. bool changed = false; - for (auto* dot : batch_dots) { - TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + for (auto* dot : non_canonical_dots) { + TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); changed = true; } + + if (decompose_batch_dot_) { + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (!dnums.lhs_batch_dimensions().empty()) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + } XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 6d0472689bf48092ceef2e9792c1358687d707ec..2b158d7a6ec510ce4cbc56bddc5cca71ac4f14f4 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -173,7 +173,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { // Find out the new dynamic dimension after reduce. int64 dimensions_not_reduced_count = 0; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int i = 0; i < operand->shape().rank(); ++i) { if (dimension == i) { parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, dynamic_size); @@ -207,7 +207,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { result_dim_mapping[i] = current_result_dims++; } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.lhs_contracting_dimensions(), i)) { if (operand_index == 0) { @@ -217,7 +217,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { } } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), @@ -433,7 +433,7 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( /* static */ StatusOr DynamicDimensionInference::Run( HloModule* module) { - VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString(); DynamicDimensionInference inference(module); TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); return inference; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 1dd196821c05cc820e2a3bf53a04d96b15484cd4..b42e67b4bbcf731d89dd8af9e46b405235a92d8a 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -62,6 +62,17 @@ class DynamicDimensionInferenceTest : public HloTestBase { return module_->AddEmbeddedComputation(embedded_builder.Build()); } + HloComputation* GetGe() { + auto embedded_builder = HloComputation::Builder("ge"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + std::unique_ptr module_; std::unique_ptr inference_; const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); @@ -487,7 +498,7 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { // Test the ability to trace select and scatter batch dimensions. auto builder = HloComputation::Builder(TestName()); auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); - auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + auto source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); Window window; // First dimension is unchanged. @@ -514,22 +525,26 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { /*parameter_number=*/0, input_shape, "A")); auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* source = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, source_shape, "B")); auto init = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); - auto* reduce_window = - builder.AddInstruction(HloInstruction::CreateReduceWindow( - output_shape, a_param, init, window, GetAdd())); + auto* sns = builder.AddInstruction(HloInstruction::CreateSelectAndScatter( + input_shape, a_param, GetGe(), window, source, init, GetAdd())); module_->AddEntryComputation(builder.Build()); TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{2, {}, 0})); TF_ASSERT_OK(RunInference()); - EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param); } } // namespace diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..e34adfd2d2bbb7214cfa2da28291b133538845e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +StatusOr DynamicIndexSplitter::Run(HloModule* module) { + bool changed = false; + + std::vector computations = + module->MakeNonfusionComputations(); + for (HloComputation* computation : computations) { + for (HloInstruction* dynamic_op : computation->MakeInstructionPostOrder()) { + switch (dynamic_op->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + break; + default: + continue; + } + auto parent = dynamic_op->parent(); + bool is_update = dynamic_op->opcode() == HloOpcode::kDynamicUpdateSlice; + int64 num_indices = dynamic_op->operand(0)->shape().rank(); + + if (num_indices == 0) { + // If the operand rank is 0, directly replace R0 DS/DUS with the + // operand (for DS) or update (for DUS). + if (is_update) { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(1))); + } else { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(0))); + } + changed = true; + continue; + } + + int64 index_operand_number = Cast(dynamic_op) + ->first_index_operand_number(); + auto index_operand = dynamic_op->mutable_operand(index_operand_number); + if (ShapeUtil::IsScalar(index_operand->shape())) { + // This DS/DUS already uses scalar indices. + continue; + } + TF_RET_CHECK(index_operand->shape().rank() == 1); + auto index_element_type = index_operand->shape().element_type(); + std::vector index_array; + for (int64 dim = 0; dim < num_indices; ++dim) { + auto slice = parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(index_element_type, {1}), index_operand, {dim}, + {dim + 1}, {1})); + auto bitcast = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(index_element_type, {}), slice)); + index_array.push_back(bitcast); + } + auto new_dynamic_op = + is_update + ? HloInstruction::CreateDynamicUpdateSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + dynamic_op->mutable_operand(1), absl::MakeSpan(index_array)) + : HloInstruction::CreateDynamicSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + absl::MakeSpan(index_array), + dynamic_op->dynamic_slice_sizes()); + TF_CHECK_OK(parent->ReplaceWithNewInstruction(dynamic_op, + std::move(new_dynamic_op))); + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.h b/tensorflow/compiler/xla/service/dynamic_index_splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c12e3a4af287ad2272a08ba54cd99c2cad9d451 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into +// separate scalars. +class DynamicIndexSplitter : public HloModulePass { + public: + DynamicIndexSplitter() = default; + absl::string_view name() const override { return "dynamic-index-splitter"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98029d1faff7d669730f6b66e38fcefece70f0eb --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +class DynamicIndexSplitterTest : public HloTestBase {}; + +TEST_F(DynamicIndexSplitterTest, DynamicSlice) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 1) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) { + const char* const kDynamicUpdateSlice = R"( + HloModule DynamicUpdatedSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + update = s32[1,1,1] parameter(2) + ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kDynamicUpdateSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 2) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, AlreadyScalar) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + index.0 = s32[] parameter(1) + index.1 = s32[] parameter(2) + index.2 = s32[] parameter(3) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc new file mode 100644 index 0000000000000000000000000000000000000000..4db280f817141bd52e3a5b9564600a618f81aeac --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// ChooseIdentityValue looks at the instruction and returns a identity value +// which, when padded, doesn't change the result of the instruction. +// +// nullopt is returned if padding doesn't need to be reset. +StatusOr ChooseIdentityValue(HloInstruction* inst) { + HloComputation* comp = inst->parent(); + // Padding on elementwise operation doesn't affect the result of the effective + // data. + if (inst->IsElementwise()) { + return nullptr; + } + + switch (inst->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + // Because of the way we do reduce, we already require the `init` operand + // of hlo reduce instruction to be identity value. Here we reuse the + // operand. + return inst->mutable_operand(1); + } + + case HloOpcode::kConvolution: + case HloOpcode::kDot: { + // Use 0 as padding value for convolution and dot. + PrimitiveType ptype = inst->shape().element_type(); + return comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); + } + + case HloOpcode::kPad: { + return inst->mutable_operand(1); + } + case HloOpcode::kParameter: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kReshape: + case HloOpcode::kTuple: + case HloOpcode::kAllReduce: + case HloOpcode::kBroadcast: + return nullptr; + default: + return UnimplementedStrCat("Unimplimented padding for instruction: ", + inst->ToString()); + } +} + +} // namespace + +StatusOr DynamicPadder::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Pre DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* inst : computation->instructions()) { + for (int64 operand_num = 0; operand_num < inst->operand_count(); + ++operand_num) { + HloInstruction* operand = inst->mutable_operand(operand_num); + if (!operand->shape().IsArray()) { + continue; + } + for (int64 dim = 0; dim < operand->shape().rank(); ++dim) { + HloInstruction* dynamic_size = + dynamic_dimension_inference.GetDynamicSize(operand, {}, dim); + if (dynamic_size == nullptr) { + continue; + } + VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" + << dim; + TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, + ChooseIdentityValue(inst)); + if (identity_value == nullptr) { + continue; + } + + // For each dimension, first generates a mask representing the + // effective area of data and padded area of data using iota and + // dynamic_size. For example, given a dimension of 7 elements and 5 + // effective elements: + // + // iota = [0, 1, 2, 3, 4, 5, 6] + // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5] + // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f] + // + // Once the mask is generated, the input data is then padded using the + // mask and pad value. + // + const Shape mask_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::U32); + const Shape pred_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::PRED); + HloInstruction* iota = computation->AddInstruction( + HloInstruction::CreateIota(mask_shape, dim)); + + HloInstruction* broadcasted_effective_size = + computation->AddInstruction(HloInstruction::CreateBroadcast( + mask_shape, dynamic_size, {})); + HloInstruction* pred = computation->AddInstruction( + HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, + broadcasted_effective_size)); + + HloInstruction* broadcasted_identity_value = + computation->AddInstruction(HloInstruction::CreateBroadcast( + operand->shape(), identity_value, {})); + HloInstruction* padded = + computation->AddInstruction(HloInstruction::CreateTernary( + operand->shape(), HloOpcode::kSelect, pred, operand, + broadcasted_identity_value)); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); + operand = inst->mutable_operand(operand_num); + changed = true; + } + } + } + } + HloDCE dce; + TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); + VLOG(2) << "Post DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h new file mode 100644 index 0000000000000000000000000000000000000000..509269f7f56746fa5516ad917a04221587c6dcca --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// With bounded shapes, only part of the shape contains effective data and the +// rest contains padded data, whose value can be anything depending on the +// source of the data. When a bounded shape is directly consumed by an +// instruction that collapses dimensions (reduce for example), the padding data +// would affect result of the instruction. +// +// DynamicPadder uses DynamicDimensionInference to detect bounded shapes in a +// hlo module, it then inserts certain instructions to reset the padding into an +// identity value so that in doesn't affect the result of subsequent +// instruction. For example, it'd reset the padding to 0 before a bounded shape +// is consumed by a reduce-sum. +class DynamicPadder : public HloModulePass { + public: + absl::string_view name() const override { return "dynamic_padder"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55a11286e4596d87c330315322cae704fc5cd707 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicPadderTest : public HloTestBase { + protected: + DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + + StatusOr RunPadder() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); + + DynamicPadder padder; + + return padder.Run(module_.get()); + } + + void ExpectPadded(const HloInstruction* inst) { + EXPECT_THAT(inst, + op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())), + ::testing::_, op::Broadcast())); + } + + HloComputation* GetScalarAddComputation() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); +}; + +TEST_F(DynamicPadderTest, ReduceTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetScalarAddComputation())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(reduce->operand(0)); +} + +TEST_F(DynamicPadderTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(conv->operand(0)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc index c8bfc8905064bcd7b68fe259fbcc1546ff083dbd..e9c8aa03e2aa8f4866daf2a2f8d846e50fa68793 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind( } absl::optional -DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { +DynamicParameterBinding::GetBinding( + const DynamicDimension& dynamic_dimension) const { auto param_iter = bindings_.find(dynamic_dimension); if (param_iter == bindings_.end()) { return absl::nullopt; @@ -70,7 +71,7 @@ StatusOr DynamicParameterBinding::CreateFromProto( int64 target_param_num = binding.target_param_num(); ShapeIndex target_param_index(binding.target_param_index().begin(), binding.target_param_index().end()); - int64 target_dim_num = binding.target_param_num(); + int64 target_dim_num = binding.target_param_dim_num(); TF_RETURN_IF_ERROR( result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, @@ -121,10 +122,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { dynamic_dimension.parameter_index)); TF_RET_CHECK( dynamic_dimension.dimension < - ShapeUtil::Rank(ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape( entry->parameter_instruction(dynamic_dimension.parameter_num) ->shape(), - dynamic_dimension.parameter_index))); + dynamic_dimension.parameter_index) + .rank()); return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h index dd474d8eed1b2c30ddb8f624a864198c74eacaba..57af2c43d3c65f7340e6a9f04e5abbf052ebceea 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -89,7 +89,7 @@ class DynamicParameterBinding { // // Returns nullopt if the binding is not set. absl::optional GetBinding( - const DynamicDimension& dynamic_dimension); + const DynamicDimension& dynamic_dimension) const; using BindingFn = std::functionToProto(); + TF_ASSERT_OK_AND_ASSIGN(*binding, + DynamicParameterBinding::CreateFromProto(proto)); + } +}; TEST_F(DynamicParameterBindingTest, SimpleBinding) { // 'b' is a dynamic shape; 'a' represents the real size of b's first @@ -56,15 +64,20 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, - /*parameter_index=*/{}, - /*dimension=*/0}); - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBinding) { @@ -89,16 +102,21 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { @@ -127,26 +145,35 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 1})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - - absl::optional param2 = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - EXPECT_TRUE(param2); - EXPECT_EQ(param2->parameter_num, 0); - EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); - - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + + test(binding); + + SerializeAndDeserialize(&binding); + + // Test the binding again after deserialization. + test(binding); } } // namespace diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 6f1f95f2e9082649b6ca9cc0da5c238e15b77c10..727e0bfa52d45b6f8c67d7d04613e4865f18a53c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -812,11 +812,14 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); + auto zero = llvm::ConstantFP::get(a->getType(), 0); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto one = llvm::ConstantFP::get(a->getType(), 1); auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); @@ -828,7 +831,13 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); + // 0^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + return Select( + And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), + EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), + EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); } default: return Unimplemented("binary complex op '%s'", @@ -1327,9 +1336,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. - CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); + CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { + for (int64 i = 0; i < hlo.shape().rank(); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { @@ -1750,7 +1759,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. llvm::Type* index_type = index.GetType(); llvm_ir::IrArray::Index slice_start_index(index_type, rank); @@ -1758,9 +1767,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); + // TODO(b/118437727): Remove the R1 path. + llvm::Value* start_index_value; + if (hlo->operand(1)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); + } // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1893,7 +1911,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which @@ -1905,9 +1923,19 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); + + llvm::Value* start_index_value; + // TODO(b/118437727): Remove the R1 path. + if (hlo->operand(2)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(2))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); + } // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) @@ -2225,7 +2253,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( auto* iota = Cast(hlo); PrimitiveType element_type = iota->shape().element_type(); IrArray::Index elem_index = - ShapeUtil::Rank(iota->shape()) > 1 + iota->shape().rank() > 1 ? target_index.SourceIndexOfBroadcast( iota->shape(), ShapeUtil::MakeShapeWithDescendingLayout( diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 01cef499665c050d4453382289168276028e1d26..590942cddcdd138981ee829f090ae17b0d038e1a 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -153,10 +153,9 @@ static StatusOr> GatherLoopBody( dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); HloInstruction* index_vector; @@ -222,7 +221,7 @@ static StatusOr> GatherLoopBody( {operand, start_indices, updated_accumulator}}; } -static StatusOr CreateGatherLoopAccumulatorInitValue( +static HloInstruction* CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { @@ -332,12 +331,10 @@ StatusOr GatherExpander::ExpandGather( CHECK_EQ(gather_loop_trip_count, canonical_start_indices->shape().dimensions(0)); - TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_init, - CreateGatherLoopAccumulatorInitValue( - computation, output_shape.element_type(), - gather_instr->gather_slice_sizes(), gather_loop_trip_count, - gather_instr->gather_dimension_numbers())); + HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_slice_sizes(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers()); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index a3102368cb1dba15da7422337666d278cef775ab..e1ea5c39d58b6d23b076740626ca0ad63dc341ee 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -89,7 +89,7 @@ ENTRY main { // an implementation detail from WhileUtil::MakeCountedLoop). const Shape& while_shape = while_instr->shape(); - ASSERT_TRUE(ShapeUtil::IsTuple(while_shape)); + ASSERT_TRUE(while_shape.IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4); EXPECT_TRUE(ShapeUtil::SameDimensions( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index bec02e14f951c6d905b7329be5c02896984279d0..7d450f4b53cdea209f2ef10ba785be6ec3b8bf8d 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -83,7 +83,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), @@ -120,7 +120,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); - if (ShapeUtil::IsArray(device_subshape)) { + if (device_subshape.IsArray()) { TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6c23f921f40cac0dc5df08494dc1b63e6d1d5e93..dc17aa4426236f54e5f03c28634278d45f462158 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3,6 +3,11 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 @@ -24,12 +29,6 @@ filegroup( ]), ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - xla_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], @@ -94,8 +93,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -135,6 +134,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -263,7 +264,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -362,6 +365,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -695,6 +699,8 @@ cc_library( "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", + "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -712,6 +718,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -725,6 +732,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/cuda:cuda_diagnostics", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1005,14 +1013,10 @@ cc_library( srcs = ["variadic_op_splitter.cc"], hdrs = ["variadic_op_splitter.h"], deps = [ - ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 528209abc75777440163c2e1512658b8ad36315b..eb59ee5a1d47b6b706ef3f53a76069b3538eb6b7 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -57,16 +58,16 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. - if (registered_buffers_.count(i)) { - se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % expected_alignment != + if (se::DeviceMemoryBase* address = + tensorflow::gtl::FindOrNull(registered_buffers_, i)) { + if (reinterpret_cast(address->opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %d must be a multiple of %x, but " "was %p", - i, kEntryParameterAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address->opaque()); } - buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); + buffer_allocations->SetBuffer(i, *address); continue; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 14186b8faa68ad8492ea4863fcd7bd746e2eae48..9413ac2cff7c8d3ec4be6662569c580060bf1173 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -52,7 +53,8 @@ class BufferAllocations { DeviceMemoryAllocator* memory_allocator); private: - std::map registered_buffers_; + absl::flat_hash_map + registered_buffers_; }; ~BufferAllocations(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b..309b0aca64954e64509d731dce28ce9d8da4ee43 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -146,7 +146,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { +CudnnConvAlgorithmPicker::PickBestAlgorithm( + const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. const bool cross_check_enabled = @@ -249,12 +250,13 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - backend_config.set_algorithm(alg.algo_id()); - backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); - TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + // Use assignment instead of brace-list to make GCC 4.9 happy. + RunConvOptions options; + options.profile_result = &profile_result; + options.algo_override = alg; bool launch_ok = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, &profile_result) + &scratch_allocator, &stream, options) .ok(); if (launch_ok && profile_result.is_valid()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 642af787afc71586d722ecc7e529ed8b3fa64d33..4991db0948589e479a202f4082d96df275f6e088 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -56,7 +56,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm( + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index 5aa4f839f4be5f1060480fea98775f8ffada0bdd..958e0b9c6e7b7885f87b90d61ee5b3bbf6ab2702 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr, auto* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank()); bool added_padding = false; - for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + for (int64 dim = 0; dim < shape.rank(); ++dim) { if (shape.dimensions(dim) == new_shape.dimensions(dim)) { continue; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index 3a09d4d4716950a09d65dd093272482d55ac5c27..17d0f7aa7bf6031148aae79f74f7878d6fca9574 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -219,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = - MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + MakeNoPaddingConfig(input->shape().rank()); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807..b628f27f4b2ba8ccf17fd531d8a0c25cb99d9396 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -395,32 +395,36 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, - stream, profile_result); + stream, options); } Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { TF_ASSIGN_OR_RETURN(CudnnConvParams params, GetCudnnConvParams(conv, operand_buffers, result_buffer)); + if (options.algo_override) { + params.algorithm = AlgorithmConfig(*options.algo_override); + } + PrimitiveType output_primitive_type = conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F32: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F64: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index edbc75a94a1238540390b93f0fa5217852c7781f..25b2461ca61251c6cb7b89b1f91da0f1636a3647 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -28,6 +28,14 @@ limitations under the License. namespace xla { namespace gpu { +struct RunConvOptions { + // Nullable output-parameter pointer for profiling results. + se::dnn::ProfileResult* profile_result = nullptr; + + // Use this algorithm, instead of the one from the instruction. + absl::optional algo_override; +}; + // This file contains low-level routines for running cudnn convolutions. // Calls into cudnn to run the specified convolution. @@ -46,13 +54,13 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 470457935acacb8940af241dadb393d770786939..91930eccdff94bb2fc85636f3a4b2d661c618d87 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -35,7 +35,7 @@ namespace { // Traverses users of tuple shape, adding leaf instructions to 'instructions'. void MaybeResolveTupleElements(HloInstruction* instruction, std::vector* instructions) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { for (auto tuple_user : instruction->users()) { MaybeResolveTupleElements(tuple_user, instructions); } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 27f07b1d58125092c1ed6734b238e4ae0f11c4aa..86c9bc6a345047fb5329af0be45c8981cc427f50 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case C64: return &DoGemm>; + case C128: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case C64: return &DoGemmWithAlgorithm>; + case C128: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case C64: return &DoGemmAutotune>; + case C128: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF64; case C64: return se::blas::ComputationType::kComplexF32; + case C128: + return se::blas::ComputationType::kComplexF64; default: LOG(FATAL) << "Unsupported type."; } @@ -315,8 +323,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(output_shape_)); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank()); int64 row_dim = dim_nums.lhs_batch_dimensions_size(); int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index ae2e718db29803a085401969a7d9b09abf690a6c..434060ad89dac7ad65c790c8c0a7f3d6ad62a25a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -218,7 +218,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); if (!ShouldEmitLiteralInLlvmIr(literal)) { VLOG(3) << "H2D memcpy for constant with shape " << ShapeUtil::HumanString(literal.shape()); @@ -310,12 +310,34 @@ StatusOr GpuExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); se::DeviceMemoryBase src_base = buffer_allocations->GetDeviceAddress(slice.index()); CHECK(!src_base.is_null() || src_base.size() == 0); - *device_memory = src_base; + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, it + // means the caller aliased some parameter buffer to an output one + // (via the HloInputOutputAliasConfig API). If that is the case, the + // caller will receive a partially complete scoped shaped buffer, + // which they will have to fill up on return. + // Unfortunately the interface to the execute APIs are ShapedBuffer + // pointer based, which assumes caller ownership, and hence a buffer + // coming from there cannot be part of the new ScopedShapedBuffer we + // create for the result (which assumes ownership). + *device_memory = src_base; + } else { + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } buffers_in_result.insert(src_base); return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 452e763a8eaadc805cd3a3859a68e2a31598fd36..842ba2fdcd31a451cec1be543e102e0a46077f38 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -42,15 +42,13 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, int64 max_rank = -1; const Layout* max_rank_layout; for (HloInstruction* param : params) { - if (ShapeUtil::IsArray(param->shape()) && - ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); + if (param->shape().IsArray() && param->shape().rank() > max_rank) { + max_rank = param->shape().rank(); max_rank_layout = ¶m->shape().layout(); } } return absl::c_all_of(params, [&](HloInstruction* param) { - return (!ShapeUtil::IsArray(param->shape())) || - (ShapeUtil::Rank(param->shape()) < max_rank) || + return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index f59da2caa18646676297e66dd329c66fb5fddf1b..58bdd4209a2315cdb7d29e920faded4d1a6a5876 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(instruction->shape())); + instruction->shape().rank()); for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { - CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + CHECK_LT(batch_dim, instruction->shape().rank() - 2); } // Set both inputs and the output to default layout. @@ -215,18 +215,18 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && - ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout. Shape keys_shape = instruction->operand(0)->shape(); Layout keys_layout = - LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank()); for (int64 i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; TF_RETURN_IF_ERROR( constraints->SetOperandLayout(shape, instruction, i)); const LogicalBuffer* output_buffer; - if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { TF_ASSIGN_OR_RETURN( output_buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f3c274429242d5c989146d14ea523b5910408cff..8c6a6914792a96ab517fa5f20ff2215e4785490e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -59,7 +59,7 @@ Status GpuTransferManager::TransferLiteralToInfeed( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(literal_subshape)) { + if (literal_subshape.IsArray()) { int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); TF_ASSIGN_OR_RETURN( *buffer_tree.mutable_element(index), @@ -126,13 +126,12 @@ static void ShapeTreeToLiteral( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); return; } - CHECK(ShapeUtil::IsTuple(shape)) - << ShapeUtil::HumanStringWithLayout(shape); + CHECK(shape.IsTuple()) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); for (int64 i = 0; i < tuple_element_count; ++i) { @@ -158,7 +157,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( std::unique_ptr* buffer) { const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index); // Do not transfer tuple index buffers. - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return; } *buffer = absl::make_unique( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 51627402b45f594dab3480129ba182d54d01b811..69aaaceca112364a4fd562f6a5eff1629fd3fc54 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos( // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. - std::set already_bound_for_this_function; + absl::flat_hash_set already_bound_for_this_function; auto arg_iter = function->arg_begin(); for (const HloInstruction* io_hlo : io_hlos) { - if (!already_bound_for_this_function.count(io_hlo)) { + if (!already_bound_for_this_function.contains(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); } else { @@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos( temp_buffer_base_->setName("temp_buffer"); for (const HloInstruction* non_io_hlo : non_io_hlos) { - if (already_bound_for_this_function.count(non_io_hlo)) { + if (already_bound_for_this_function.contains(non_io_hlo)) { continue; } already_bound_for_this_function.insert(non_io_hlo); @@ -280,7 +281,7 @@ string HloToIrBindings::ToString() const { StrAppend(&s, " ", instr->ToString()); const ShapeTree& shape_tree = it->second; - if (!ShapeUtil::IsTuple(instr->shape())) { + if (!instr->shape().IsTuple()) { const llvm::Value* val = shape_tree.begin()->second; StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); continue; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index c0edae530cedba45c897b07b7b9cc72eaaab397c..f57b594e9c18078a3bbbf4d2b4db7e989c4edfdd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -61,7 +62,7 @@ class HloToIrBindings { // Returns whether `hlo` is bound to an LLVM IR value. bool BoundToIrValue(const HloInstruction& hlo) const { - return base_ptrs_.count(&hlo); + return base_ptrs_.contains(&hlo); } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } @@ -110,7 +111,8 @@ class HloToIrBindings { // For an instruction that generates multiple outputs, the root will be a // tuple shape. The IrArray for each element output is stored in the subnode // in the ShapeTree. - std::unordered_map> base_ptrs_; + absl::flat_hash_map> + base_ptrs_; // The address of the memory block that contains all temporary buffers. llvm::Value* temp_buffer_base_ = nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 8c3a026740851767855beae59d6a3c92f7a0d6bd..676380c3b10f9a20c641eea0d9a948a26becaddc 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -36,6 +36,21 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, ShapeTree infeed_buffers = GetOrCreateInfeedManager()->BlockingGetNextDestination(); + // infeed_slices_'s shape should be a tuple of shape (buffers, token). + const auto& infeed_shape = infeed_slices_.shape(); + TF_RET_CHECK(infeed_shape.IsTuple()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes().size() == 2) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes(1).IsToken()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK( + ShapeUtil::Equal(infeed_buffers.shape(), infeed_shape.tuple_shapes(0))) + << "Expected infeed of shape " + << ShapeUtil::HumanStringWithLayout(infeed_shape.tuple_shapes(0)) + << " but was " + << ShapeUtil::HumanStringWithLayout(infeed_buffers.shape()); + { // The infeed buffer has an extra outer tuple with a token. Adjust the index // accordingly. @@ -45,7 +60,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), ShapeIndexView(index, 1)); // For the leaf buffers of the tuple copy the elements directly. - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { const BufferAllocation::Slice& tuple_element_buffer = infeed_slices_.element(index); se::DeviceMemoryBase tuple_element_address = diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6151dd8ff4c92bb81bd756c68cc9377633c8c9d5..f07141029cbf8b034b74548f6fca8f1628589f0c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -282,22 +282,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, int64 operand_index) { - const HloInstruction* producer = consumer->operand(operand_index); - // The IR emitter has limited support for non-loop fusions with multi output - // at present. - // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { - return false; - } - // Multi-output fusion requires instructions with compatible shapes. - if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { - return false; - } - // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for - // multi-output fusion. In particular, do not check whether an instruction is - // expensive to duplicate, since this doesn't matter here. - return GpuInstructionFusion::ShouldFuse(consumer, operand_index); + return false; } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 688604cd36e5a45debf855aacd29d05ecda92341..a05ab86cf77a134a1fc387d93cb482aa1ff5345b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -506,202 +506,11 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { })") .ValueOrDie(); - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT( - fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); -} - -TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { - // tanh --> add --> tuple - // \---------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - tanh = f32[4,3]{1,0} tanh(p0) - add = f32[4,3]{1,0} add(tanh, p1) - ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) - })") - .ValueOrDie(); - - // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + // Multi-output fusion is disabled here and performed in the + // GpuMultiOutputFusion pass instead. ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion2) { - // sub --> add1 --\--------\ - // \----------> add2 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(sub, add1) - ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Add()), - op::Add(op::Subtract(), op::Parameter()))); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion3) { - // sub --> add1 ----\--------\ - // \ --> add2 --> add3 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - p3 = f32[4,3]{1,0} parameter(3) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(p2, sub) - add3 = f32[4,3]{1,0} add(add1, add2) - ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Add(), op::Add()), - op::Add(op::Parameter(), op::Subtract()))); -} - -TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { - // sub --> mul ---\ - // \--> call --> add --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - c = f32[] constant(42) - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - sub = f32[4,3]{1,0} subtract(p0, p1) - mul = f32[4,3]{1,0} multiply(sub, c) - call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" - add = f32[4,3]{1,0} add(mul, call) - ROOT tuple = (f32[4,3]{1,0}) tuple(add) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - // Visit instructions in post order to detect cycles. - // TODO(tjoerg): Add cycle detection to the HloVerifier. - class DummyVisitor : public DfsHloVisitorWithDefault { - public: - DummyVisitor() {} - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - } visitor; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - // Accept will return a FailedPrecondition when a cycle is detected. - EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); - } -} - -TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { - // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) - // \-------------------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[2,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[2,3]{1,0} parameter(2) - sub = f32[2,3]{1,0} subtract(p0, p2) - add = f32[4,3]{1,0} add(sub, p1) - ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) - })") - .ValueOrDie(); - - // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` - // have incompatible shapes, expect that no multi-output fusion happens. - ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { - auto module = ParseHloString(R"( - HloModule test_module - - add_computation { - add_lhs = f32[] parameter(0) - add_rhs = f32[] parameter(1) - ROOT add_root = f32[] add(add_lhs, add_rhs) - } - - fused_computation { - p1 = f32[10] parameter(0) - zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, - to_apply=add_computation - } - - ENTRY entry { - p0 = f32[10] parameter(0) - mul = f32[10] multiply(p0, p0) - fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation - ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) - })") - .ValueOrDie(); - - // Multi-output fusion is not supported for non-loop fusions at present. Since - // `fused_computation` is a input fusion, expect no multi-output fusion to - // happen. - ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); + .ValueOrDie()); } TEST_F(InstructionFusionTest, FuseScalarConstant) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 33e41a2782b5932430eea621d3cea2c6634f292f..82bdd677d96d3d0826bb4127b32d074eb632b1a3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -40,7 +40,7 @@ namespace { // Return whether the given shape is rank 2 excluding the batch dimensions. bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; + return shape.rank() == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes @@ -54,7 +54,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64 || output_primitive_type == C64); + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128); return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && IsRank2(rhs_shape, batch_dimensions_size) && IsRank2(output_shape, batch_dimensions_size) && @@ -154,20 +155,17 @@ bool IsReductionToVector(const HloInstruction& reduce) { const HloInstruction* input = reduce.operand(0); std::vector dims_to_keep; for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { - if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(), - dim)) { + if (!absl::c_linear_search(reduce.dimensions(), dim)) { dims_to_keep.push_back(dim); } } return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), dims_to_keep) && - ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions( - [&dims_to_keep](int64 dim) { - return std::count( - dims_to_keep.begin(), - dims_to_keep.end(), dim); - }, - input->shape())); + ShapeUtil::Equal( + reduce.shape(), + ShapeUtil::FilterDimensions( + [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, + input->shape())); } // This emits a device-side call to diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 22db38ee03b9990cc2f21a01b6c0f2249d0991ea..0007a9a8a3369d8ac010640127e1561615a6d813 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -430,7 +430,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), GetIrArray(*pred, *tuple_select), GetBasePointer(*on_true), GetBasePointer(*on_false), @@ -648,7 +648,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { Status IrEmitter::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on GPU"); } auto arg = reduce->operand(0); @@ -783,7 +783,7 @@ StatusOr IrEmitter::ComputeNestedElement( std::vector IrEmitter::ConstructIrArrayForOutputs( const HloInstruction& hlo) { std::vector output_arrays; - if (ShapeUtil::IsTuple(hlo.shape())) { + if (hlo.shape().IsTuple()) { int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); output_arrays.reserve(num_outputs); for (int64 i = 0; i < num_outputs; ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1472853dc443f0190c3bbed7f96c91ec65ae6dda..294a454931b5cfa368bf094c428a1e942f4556b8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -88,6 +89,9 @@ namespace xla { namespace gpu { using llvm_ir::KernelMappingScheme; +using EmitElementFunction = + std::function; namespace { @@ -292,13 +296,12 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, auto shape_in_range = [&](const Shape& s) { bool in_range = true; - ShapeUtil::ForEachSubshape( - s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(sub_shape) && - !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { - in_range = false; - } - }); + ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, + const ShapeIndex& /*index*/) { + if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); return in_range; }; @@ -542,8 +545,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // HandleFusion specializes reduction from a multi-dimensional array to // a 1D array. The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. - if (root->opcode() == HloOpcode::kReduce && - ShapeUtil::IsTuple(root->shape())) { + if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { // TODO(b/112040122): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } @@ -634,7 +636,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Multi-output reduce is not supported on GPU"); } if (IsReductionToVector(*reduce)) { @@ -698,8 +700,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto* source = select_and_scatter->operand(1); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, @@ -1015,7 +1017,7 @@ Status IrEmitterUnnested::EmitScatter( int64 raw_window_multidim_idx = 0; std::vector input_window_multidim; std::vector input_window_bounds; - for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) { if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_window_bounds.push_back(1); // Trivial dimension. input_window_multidim.push_back(index.GetConstantWithIndexType(0)); @@ -1027,12 +1029,11 @@ Status IrEmitterUnnested::EmitScatter( ++raw_window_multidim_idx; } } - DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + DCHECK_EQ(input_window_multidim.size(), operand->shape().rank()); // Insert a 1 dimension at the end if index_vector_dim requests one. Shape scatter_indices_shape = scatter_indices->shape(); - if (dim_numbers.index_vector_dim() == - ShapeUtil::Rank(scatter_indices_shape)) { + if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { scatter_indices_shape.add_dimensions(1); scatter_indices_shape.mutable_layout()->add_minor_to_major( dim_numbers.index_vector_dim()); @@ -1310,7 +1311,7 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { // HloModuleConfig::num_replicas changes between when the module is compiled // and when it's run. if (crs->operand_count() == 1) { - CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + CHECK(crs->operand(0)->shape().IsArray()) << "Operands to all-reduce must be arrays: " << crs->ToString(); AddThunkToThunkSequence(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), @@ -1509,10 +1510,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return !allocation->is_constant(); }); - std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), - [](const BufferAllocation* a, const BufferAllocation* b) { - return a->index() < b->index(); - }); + absl::c_sort(non_constant_buffers, + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); @@ -2080,12 +2081,36 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( return Status::OK(); } +namespace { + +// Returns true if the fusion contains any instruction that is likely +// translated to complex LLVM IR, such as loops, and prevent vectorization. +bool MayPreventVectorization(const HloInstruction& fusion_hlo) { + CHECK_EQ(fusion_hlo.opcode(), HloOpcode::kFusion); + return absl::c_any_of( + fusion_hlo.fused_instructions_computation()->instructions(), + [&](const HloInstruction* instr) { + switch (instr->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + return true; + default: + return false; + } + }); +} + +} // namespace + Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { int unroll_factor = 1; // Unfused elementwise operations are usually memory bound, unroll them. - if (hlo.IsElementwise() || hlo.opcode() == HloOpcode::kFusion) { + if (hlo.IsElementwise() || + (hlo.opcode() == HloOpcode::kFusion && !MayPreventVectorization(hlo))) { unroll_factor = ComputeMaxUnrollFactor(&hlo); } @@ -2136,53 +2161,86 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -void EmitFullElementalTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Type* index_ty, - const std::function& emit_elem_function) { +std::tuple GetStartOffsetAndStepForX( + int64 tile_size_x, int64 num_threads_x, + const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder, + llvm::Value* x, llvm::Type* index_ty) { + llvm::Value* start_offset_x; + int64 step_x; + if (mapping_scheme->DilatedX()) { + start_offset_x = x; + step_x = num_threads_x; + } else { + start_offset_x = builder->CreateMul( + x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x)); + step_x = 1; + } + return std::make_tuple(start_offset_x, step_x); +} + +void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { - IrArray::Index source_idx_y = tile_origin_index.AddOffsetToDim( + IrArray::Index source_idx_y = source_idx.AddOffsetToDim( y_indvar, KernelMappingScheme::DimY, builder); llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = source_idx_y.AddOffsetToDim( - llvm::ConstantInt::get(index_ty, j), + + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j * step_x), KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), + start_offset_x); + emit_elem_function(source_idx_y_x, y_loc, x_loc, j); } }); } -void EmitPartialElementalTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, + KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, + llvm::Value* tile_width, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_x = + source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x); ksl->If( loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), @@ -2202,14 +2260,13 @@ void EmitPartialElementalTile( /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->If( - loop_name + "_y_in_tile", - builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function( - source_idx.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder), - y_loc, x_loc); - }); + ksl->If(loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx_x.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc, j); + }); }); }); } @@ -2228,8 +2285,7 @@ void EmitTiledElementalCodeWithBoundsCheck( const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); @@ -2265,7 +2321,7 @@ void EmitTiledElementalCodeWithBoundsCheck( void IrEmitterUnnested::EmitTileElementForCopy( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); // TODO(jlebar): Add AA metadata to this load. @@ -2295,7 +2351,7 @@ void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForFusion( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); @@ -2396,6 +2452,23 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { : llvm_ir::KernelMappingScheme::DimX; } + int GetNumberOfPartialResults() const { + if (IsRowReduction()) { + return 1; + } + int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX(); + int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX(); + CHECK_EQ(tile_size % num_thread, 0); + return tile_size / num_thread; + } + + int GetPartialResultIndex(int64 x_iter_num) const { + if (IsRowReduction()) { + return 0; + } + return x_iter_num; + } + private: AddressVector partial_result_addresses_; AddressVector reduction_input_addresses_; @@ -2455,10 +2528,11 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( llvm::AllocaInst* reduction_input_address = Alloca(element_type); reduction_input_addresses->push_back(reduction_input_address); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = - Alloca(element_type, /*ArraySize=*/nullptr, + Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results), "partial_reduction_result." + llvm::Twine(reduce_idx)); partial_result_addresses->push_back(partial_result_address); @@ -2481,7 +2555,9 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); } - Store(init_ir_value, partial_result_address); + for (int i = 0; i < num_partial_results; ++i) { + Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)})); + } } void IrEmitterUnnested::EmitPrologueForReduction( @@ -2519,10 +2595,14 @@ void IrEmitterUnnested::EmitPrologueForReduction( std::move(output_shape_index)); } - // Allocate stack storage to store the current output linear index and record - // the address of the storage. + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + + // Allocate stack storage to store the linear indices for the current output, + // and record the address of the storage. reduction_info->SetCurrentOutputLinearIndexAddress( - Alloca(reduction_info->GetIndexType())); + Alloca(reduction_info->GetIndexType(), + /*ArraySize=*/b_.getInt32(num_partial_results), + "current_output_linear_index_address")); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); @@ -2592,36 +2672,45 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the // if-statement emitted above. for (int i = 0; i != num_reduces; ++i) { - IrArray::Index element_index( - /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); - llvm::Value* output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]) - .EmitArrayElementAddress(element_index, &b_, - "output_element_address"); - // Do not emit atomic operations if each element in the reduction result is - // computed by one block, that is the dimension being reduced has only one - // block. - const llvm_ir::KernelMappingScheme* mapping_scheme = - reduction_info->GetKernelMappingScheme(); - if (mapping_scheme->GetTileBlockSizeForDimension( - llvm_ir::KernelMappingScheme::DimZ) == 1 && - mapping_scheme->GetTileBlockSizeForDimension( - reduction_info->GetReducedDimensionEnum()) == 1) { - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {output_address, partial_result_addresses[i]}, - output_address)); - } else { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_result_addresses[i])); + for (int j = 0; j < num_partial_results; ++j) { + IrArray::Index element_index( + /*linear=*/Load( + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(j)}), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result + // is computed by one block, that is the dimension being reduced has only + // one block. + const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)}))); + } } } } @@ -2629,7 +2718,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion ? unnested_hlo->fused_expression_root() @@ -2642,8 +2731,11 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Record the linear address for the current reduction. const ReductionCodegenInfo* reduction_info = dynamic_cast(kernel_info); + int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num; + Store(index[reduction_info->GetKeptDimensionEnum()], - reduction_info->GetCurrentOutputLinearIndexAddress()); + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(partial_result_index)})); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); llvm::AllocaInst* output_inbound_addr = @@ -2690,6 +2782,13 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + if (num_partial_results > 1) { + // Clear the linear index field of the IrArray::Index to enable the use of + // GetElementPointer with array types. This enables the vectorization of + // the computation for different partial results. + input_index.ClearLinearIndex(); + } absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); absl::Span reduction_input_addresses = @@ -2702,10 +2801,12 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0; i != reducers.size(); ++i) { llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); Store(input_ir_value, reduction_input_addresses[i]); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_addresses[i], + {b_.getInt32(partial_result_index)}); TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, - partial_reduction_result_addresses[i])); + *reducers[i], {partial_result_address, reduction_input_addresses[i]}, + partial_result_address)); } // Emit code to generate the output for the non-reduction instructions in the @@ -2716,8 +2817,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, + KernelCodegenInfo* kernel_info, + KernelSupportLibrary* ksl, llvm::Type* index_ty) { KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); @@ -2750,15 +2851,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, llvm::Value* num_tiles_in_block = Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - - ksl.For(loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl->For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -2813,7 +2913,8 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, // unnested_hlo: The unnested hlo instruction for which the kernel is generated. // Currently, these hlo instructions are supported: kLoop fusion, kCopy. // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of -// other tensors with the same dimensions and need to be tiled and tranposed. +// other tensors with the same dimensions and are safe to be tranposed via +// the shared memory tranpose implementation. // mapping_scheme: The tiling scheme to use. // kernel_generator: Contains function objects for code generation, such as // element generator, block prologue and epilogue generators. @@ -2901,8 +3002,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, &ksl, &b_, y, x, tile_height, tile_width, emit_elem_function); @@ -2915,10 +3015,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( const IrArray::Index input_tile_origin( Permute({0, 2, 1}, output_tile_origin.multidim())); - const IrArray::Index input_index = - input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. @@ -2928,9 +3024,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Note that tile_width and tile_height are flipped here because we are // reading a transposed tile. emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + input_tile_origin, "input", output_tile_bounds[2], + output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { for (int64 id : tiled_param_ids) { IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; @@ -2950,18 +3047,15 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - kernel_generator.GetTileElementGenerator()(unnested_hlo, index, - kernel_info, y_loc, x_loc); + output_tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + kernel_generator.GetTileElementGenerator()( + unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }); // If a tile block contains multiple tiles and shared memory buffers are @@ -2979,7 +3073,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_prologue_generator(unnested_hlo, kernel_info); } - EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); const BlockEpilogueGenerator& block_epilogue_generator = kernel_generator.GetBlockEpilogueGenerator(); @@ -2992,7 +3086,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. +// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller +// is responsible for making sure that it is safe to apply the shared memory +// tranpose on the input parameters. +// // // For the purpose of tiling, the output tensors have a logical shape of three // components 0-2-1 while the relevant input parameters have a logical shape @@ -3025,17 +3122,19 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( element_generator = [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }; } else { DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - element_generator = [&](HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); - }; + element_generator = + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); + }; } KernelCodegenInfo kernel_info(&mapping_scheme); KernelCodeGenerator kernel_generator(std::move(element_generator)); @@ -3043,26 +3142,99 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( } namespace { -// Returns true to indicate it is safe to use the tile based shared memory -// transpose implementation to implement the kernel for the instruction. +// A recursive function to inspect the users of a parameter to determine +// whether it's safe for a parameter to participate in a shared-memory +// transpose. // -// An instruction is not safe for such an implementation if it can change the -// element order of a tensor without changing the dimension of the tensor, and -// the instruction has a corresponding elemental_ir_emitter. -bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { - auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { - HloOpcode opcode = instr->opcode(); - CHECK_NE(opcode, HloOpcode::kFusion); - return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); - }; +// Consider a fusion parameter P for which we might want to use a shmem +// transpose. If we do, we use a GPU thread block to preload a tile of P with +// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices +// cooperatively, where z, y, x are the indices for the normalized input/output +// tensor (see the document for FindTranspose021 for the definition of +// normalized tensor for 0-2-1 transpose). This shmem transpose implementation +// requires that the computation of the output tile only read elements within +// the preload tile. If this is not true, we can't use a shmem transpose for P. +// +// If the computation of output element [z, y, x] only requires the element of +// P with the same indices, the shmem tranpose implementation can be applied +// to P safely. This is a sufficient but not necessary condition. We check all +// the transitive users of P to see if we can find a user that may cause an +// exception to the situation. If such a user is not found, we conclude that P +// is safe for shmem transpose. +// +// This is trivially true for elementwise operations and some "data-movement" +// ops like kTuple. However, it's not true for operations that can change the +// dimensions of the inputs (e.g. pad, slice) and bitcast operation. +// For example: +// +// fused_computation { +// param_0 = f32[64,64]{1,0} parameter(0) +// ROOT bitcast = f32[64,64]{0,1} bitcast(param_0) +// } +// The output element at logical address [0, 63] depends on the input element +// at logical address [63, 0], which would not be within the shared-memory +// block. +// +// TODO(bixia): In order to extend this for kInput fusion, that is reduction +// with tranpose, we only need to end the use-chain checking with the input of +// a reduce operations. In this case, the above description on "output" apply +// to the result of such a use-chain, which provides the input to the reduce +// operation. +bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { + if (hlo->IsElementwise()) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); + } + + switch (hlo->opcode()) { + // Non-elementwise instructions that don't cause the shmem transpose + // to be unsafe, including the instructions that don't currently fuse. + case HloOpcode::kGetDimensionSize: + // The result of the operation doesn't rely on the content of the + // tensor. As such, there is no need to further inspect its users. + return true; + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); - if (hlo->opcode() == HloOpcode::kFusion) { - return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), - is_safe_for_tile_based_transpose); + default: + return false; } +} - return is_safe_for_tile_based_transpose(hlo); +// Given a group of input parameters that are 0-2-1 tranpose of the outputs of +// a fusion kernel, returns the input parameters that are safe for the shared +// memory tranpose implementation. +// +// When a tile based shared memory transpose is used to implement an input with +// 0-2-1 transpose, we preload a tile of the input elements +// [z, y..y+31, x..x+31] to compute the output tile elements of the same +// indices. Preloading the input tile this way is only safe when the computation +// of the output tile elements do not need any input element outside the +// preloaded tile. We inspect all the transitive users of the input parameter +// up to the fusion root instruction to see if we can find any instruction +// that can make preloading the input tile unsafe. +std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, + std::vector input_ids) { + std::vector filtered_input_ids; + for (int64 i = 0; i < input_ids.size(); ++i) { + const HloInstruction* input = fusion->fused_parameter(input_ids[i]); + if (IsInstructionSafeForShmemTranspose(input)) { + filtered_input_ids.push_back(input_ids[i]); + } else { + VLOG(10) << "Input not safe for shmem transpose " << input->ToString() + << "\n"; + } + } + return filtered_input_ids; } + } // namespace bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { @@ -3109,8 +3281,11 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - if (!IsInstructionSafeForTileBasedTranspose(hlo)) { - return false; + if (opcode == HloOpcode::kFusion) { + params_012 = FilterInputsForShmemTranspose(hlo, params_012); + if (params_012.empty()) { + return false; + } } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the @@ -3191,7 +3366,7 @@ Status AreFusedReductionOutputsConsistent( // dimensions from minor to major. DimensionVector GetDimensionsToKeepMinorToMajor( const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); DimensionVector input_dims_to_keep; for (int input_dim : input_dims) { @@ -3231,7 +3406,7 @@ std::tuple GetReductionToVectorDimensions( if (input_dims_to_keep_minor_to_major.empty()) { return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); absl::Span minor_to_major = LayoutUtil::MinorToMajor(input_shape); @@ -3253,11 +3428,101 @@ std::tuple GetReductionToVectorDimensions( return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } +// Returns true if all the transitive users of hlo before hitting users in +// use_chain_endings are elementwise operations. +bool AreUsersElementwise(const HloInstruction* hlo, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return use_chain_endings.count(user) || + (user->IsElementwise() && + AreUsersElementwise(user, use_chain_endings)); + }); +} + +// Returns the number of fusion inputs that have the same dimension as the +// given shape, and involve in only elementwise operations. +int64 NumInputsInvolveInOnlyElementwiseOps( + const HloInstruction* unnested_hlo, const Shape& op_shape, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + const Shape& parameter_shape = parameter->shape(); + return ShapeUtil::SameDimensions(op_shape, parameter_shape) && + AreUsersElementwise(parameter, use_chain_endings); + }); +} + +// Returns the number of fusion inputs that have more elements than the given +// shape. +int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo, + const Shape& shape) { + int64 num_elements = ShapeUtil::ElementsIn(shape); + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + return ShapeUtil::ElementsIn(parameter->shape()) > num_elements; + }); +} + +// The benefit of unrolling a kInput fusion that is a column reduction comes +// from the vectorization of non-reduction fusion outputs and fusion inputs. +// On the other hand, unrolling can also introduce factors that can cause +// the kernel to run slower. This routine uses a simple heuristic to estimate +// the benefit as well as the overhead of unrolling in order to decide whether +// unrolling is beneficial for the given kInput fusion. +bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, + const Shape& input_shape, + int64 num_kept) { + // TODO(b/122468062): Need further investigate to see whether we can + // remove the constraint on IsPowerOfTwo. + if (!IsPowerOfTwo(static_cast(num_kept))) { + return false; + } + + if (unnested_hlo->opcode() == HloOpcode::kReduce) { + return true; + } + + CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion); + int64 can_be_vectorized = 0; + int64 cannot_be_vectorized = 0; + const HloInstruction* fused_root = unnested_hlo->fused_expression_root(); + ConstHloInstructionSet use_chain_endings; + if (fused_root->opcode() == HloOpcode::kReduce) { + use_chain_endings.insert(fused_root); + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple); + for (const HloInstruction* instr : fused_root->operands()) { + if (instr->opcode() == HloOpcode::kReduce) { + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + // Write of the non-reduction result can be vectorized. + can_be_vectorized++; + } + use_chain_endings.insert(instr); + } + } + // Fusion inputs that have the same dimension as the reduce input and + // only involve in elementwise operations can be vectorized. + can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps( + unnested_hlo, input_shape, use_chain_endings); + // Fusion inputs with more elements than the reduce op input must participate + // in non-elementwise operations and we assume that they are not vectorizable + // for the purpose of estimating the benefit of unrolling. If the kernel is + // unrolled even with such an assumption, and the accesses to those inputs + // turn out to be vectorizable, the compiler will still vectorize them. + cannot_be_vectorized += + NumInputsWithMoreElementsThan(unnested_hlo, input_shape); + return can_be_vectorized >= cannot_be_vectorized; +} + } // namespace std::tuple IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( - const HloInstruction* first_reduce) { + const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { int64 depth = 1; int64 height = 1; int64 width = 1; @@ -3274,6 +3539,7 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( std::tie(num_reduced_major, num_kept, num_reduced_minor) = GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); CHECK_EQ(num_output_elems, num_kept); + bool dilated_x = true; if (num_kept == 1) { // Scalar reduction is a special row reduction with depth = height = 1. @@ -3288,13 +3554,21 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( is_row_reduction = false; // Column reduction without transpose doesn't require communication among // threads processing elements in the same tile. The current implementation - // only support the use of on hardware thread block to process one block of - // tiles in the KernelMappingScheme. We try to maximize the values of + // only support the use of one hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to use one thread to compute + // the partial results for two tensor elements and to maximize the values of // num_threads_x and tile_size_x to allow a bigger hardware thread block. int64 hw_threads_per_block_limit = ThreadsPerBlockLimit(ir_emitter_context_->device_description()); - tile_size_x = std::min(hw_threads_per_block_limit, num_kept); - num_threads_x = tile_size_x; + if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + num_kept)) { + tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x / 2; + dilated_x = false; + } else { + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + } int64 kNumElementsPerPartialSum = 128; tile_size_y = kNumElementsPerPartialSum; } else { @@ -3323,6 +3597,7 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( llvm_ir::KernelMappingScheme mapping_scheme( dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, num_threads_x, &b_); + mapping_scheme.SetDilatedX(dilated_x); return std::make_tuple(mapping_scheme, is_row_reduction); } @@ -3371,14 +3646,15 @@ Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { bool is_row_reduction; llvm_ir::KernelMappingScheme mapping_scheme; std::tie(mapping_scheme, is_row_reduction) = - ComputeMappingSchemeAndReductionKind(first_reduce); + ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce); ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); KernelCodeGenerator kernel_generator( /*tile_element_generator=*/ [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { - EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); }, /*block_prologue_generator=*/ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index d217ee36cf6e9b5278024a2f78513232328e7538..21b842bb2cd63ac454f85556df20ae5877cecbe1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,7 +76,6 @@ class IrEmitterUnnested : public IrEmitter { void SetLaneId(llvm::Value* v) { lane_id_ = v; } void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { - CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; } @@ -89,7 +88,7 @@ class IrEmitterUnnested : public IrEmitter { } llvm::Type* GetIndexType() const { return index_ty_; } - private: + protected: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; @@ -109,10 +108,12 @@ class IrEmitterUnnested : public IrEmitter { // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. // kernel_info: Other information to support the kernel code generation. + // x_iter_num: When a thread process N elements in the X dimension, x_iter_num + // has a value of 0..N-1 to identify the element being process. using TileElementGenerator = std::function; + llvm::Value* x_loc, int64 x_iter_num)>; // KernelCodeGenerator records the code generator objects that generate code // for tile elements or tile block prologue/epilogue. @@ -216,9 +217,13 @@ class IrEmitterUnnested : public IrEmitter { Status EmitReductionToVector(HloInstruction* unnested_hlo); // Computes the KernelMappingScheme for the reduce HLO and indicates whether - // the reduction is a row reduction. + // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo + // and first_reduce are the same instruction. For a kInput fusion, + // unnested_hlo is the fusion instruction while first_reduce is the first + // reduce op. std::tuple - ComputeMappingSchemeAndReductionKind(const HloInstruction* first_reduce); + ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo, + const HloInstruction* first_reduce); // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from @@ -243,26 +248,29 @@ class IrEmitterUnnested : public IrEmitter { const KernelCodeGenerator& kernel_generator, KernelCodegenInfo* kernel_info); void EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, llvm::Type* index_ty); + KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, + llvm::Type* index_ty); // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose. void EmitTileElementForCopy(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given kLoop fusion // HLO containing parameters that are 0-2-1 transpose of its outputs. void EmitTileElementForFusion(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given input hlo // that is either a unnested kReduce or a kInput fusion. void EmitTileElementForReduction(HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Prepares for the code generation for a tile block of a reduction kernel. void EmitPrologueForReduction(HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index bd53b90b42d8e657a3ee58e7ca03fb60522aae28..153aab97d9eb971734c5ea95564895631bc2a9fa 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -110,11 +110,9 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, } // Gets the GPU name as it's known to LLVM for a given compute capability. If -// we see an unrecognized compute capability, we return "sm_30". +// we see an unrecognized compute capability, we return "sm_35". static string GetSmName(std::pair compute_capability) { static auto* m = new std::map, int>({ - {{3, 0}, 30}, - {{3, 2}, 32}, {{3, 5}, 35}, {{3, 7}, 37}, {{5, 0}, 50}, @@ -125,8 +123,9 @@ static string GetSmName(std::pair compute_capability) { {{6, 2}, 62}, {{7, 0}, 70}, {{7, 2}, 72}, + {{7, 5}, 75}, }); - int sm_version = 30; + int sm_version = 35; auto it = m->find(compute_capability); if (it != m->end()) { sm_version = it->second; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 01fddcede64d1bb02ab89db5fc9524893c2d47a4..02e1207f377b8c28bf2566bee8cf3bcbc66794fb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, } int64 profit = 0; for (auto instr : instr2->operands()) { - if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + if (!IsProfitableOperand(instr) || !in_list.contains(instr)) { continue; } profit += ShapeUtil::ByteSizeOf(instr->shape()); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index d16c87ba5c63aa582753fe949e9e39ee2d8b81e5..40b87b16a195564c9b98497f79a70f1db0539d87 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.1 = s32[1]{0} parameter(1) p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } fusion.2 { @@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index cd369d55987b96eed2efb64ae0df6b3a76acb672..48f718b514cc9809d4100627f85af7aa05445d36 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" @@ -78,6 +80,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -152,6 +155,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -163,6 +167,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // We need a cost model for GPUs. Currently, do nothing. return false; }; + pipeline.AddPass(false); pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); @@ -194,10 +199,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); + AlgebraicSimplifierOptions options; options.set_enable_permutation_sort_replacement(true); pass.AddPass(options); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -266,10 +271,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions options( - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_permutation_sort_replacement(true); pipeline.AddPass>(options); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 4775baf44aecfe6adaf2bf0d2791595436635b16..1dedbd3befce6e2ceb06126d83a061207a90dd8f 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -25,7 +26,7 @@ namespace xla { namespace gpu { bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { - return hlo_to_stream_number_.count(&hlo); + return hlo_to_stream_number_.contains(&hlo); } int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { @@ -98,10 +99,10 @@ int ComputeStreamToAssign( // greedy approach. First, we compute as forbidden_stream_numbers the // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign // `hlo` a different stream. - std::set forbidden_stream_numbers; + absl::flat_hash_set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_num) && + if (!forbidden_stream_numbers.contains(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_num); } @@ -109,7 +110,7 @@ int ComputeStreamToAssign( for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); ++stream_num) { - if (!forbidden_stream_numbers.count(stream_num)) { + if (!forbidden_stream_numbers.contains(stream_num)) { return stream_num; } } diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index d798b31643782eb25bba08227e29903ec0e7a597..d8bd9f7f6df48fe2faf510b369b99b6cd2173608 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -47,6 +47,21 @@ cc_library( ], ) +tf_cc_test( + name = "gpu_buffer_assignment_test", + srcs = ["gpu_buffer_assignment_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_buffer_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_buffer_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1335d73494100788f3ffe1bd0f5eb200de79cb21 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_buffer_assignment_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuBufferAssignmentTest : public GpuCodegenTest { + public: + HloModuleConfig ConfigWithoutHloPasses() { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // Disable layout_assignment to use the preassigned layouts. + debug_options.xla_disable_all_hlo_passes(); + config.set_debug_options(debug_options); + return config; + } +}; + +TEST_F(GpuBufferAssignmentTest, InstructionNameWithHyphenSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal-to = s32[2]{0} constant({42, 73}) + })"; + + // Check that '-' in the instruction name is changed to '_'. + auto hlo_module = ParseHloString(kHloString).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: buffer_for_equal_to = +)", + /*match_optimized_ir=*/true); + + // TODO(bixia): The run fails randomly. + // Check that the kernel runs correctly. + // EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuBufferAssignmentTest, BufferSanitizedNameCollisionResolved) { + const char *const kHloString = R"( + HloModule BufferSanitizedName + ENTRY kernelEntry { + equal.to = s32[2]{0} constant({42, 73}) + equal-to = s32[2]{0} constant({67, 3}) + ROOT add = s32[2]{0} add(equal.to, equal-to) + })"; + + // Turn of Hlo passes to prevent constant folding. + // + // Check that '-' and '.' in the instruction name are changed to '_', and + // name collision is resolved by LLVM. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutHloPasses()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: buffer_for_equal_to = +; CHECK: buffer_for_equal_to1 = +)", + /*match_optimized_ir=*/false); + + // TODO(bixia): There is another bug that prevents this from running + // correctly. + // Check that the kernel runs correctly. + // EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a302b582ede3723acd118d2e4a4bb3efdf7a4d0b..869724db601b2d5e4ed6d3c7bf3e10a748433146 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -65,7 +65,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -91,7 +91,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -118,7 +118,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -152,7 +152,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -187,13 +187,13 @@ TEST_F(GpuKernelTilingTest, CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } -TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { +TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { const char *const kHloString = R"( HloModule FusionTransposeWithReverseNotTiled fused_computation.1 { @@ -214,12 +214,203 @@ TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { + const char *const kHloString = R"( + HloModule TransposedInputWithUserBitcast + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { + const char *const kHloString = R"( + HloModule TwoTransposedInputs + + fused_computation { + param_0 = f32[64,64]{1,0} parameter(0) + param_1 = f32[64,64]{1,0} parameter(1) + bitcast = f32[64,64]{0,1} bitcast(param_0) + copy = f32[64,64]{0,1} copy(param_1) + ROOT tuple = (f32[64,64]{0,1}, f32[64,64]{0,1}) tuple(bitcast, copy) + } + + ENTRY kernel_entry { + parameter.0 = f32[64,64]{1,0} parameter(0) + parameter.1 = f32[64,64]{1,0} parameter(1) + ROOT fusion = (f32[64,64]{0,1}, f32[64,64]{0,1}) + fusion(parameter.0, parameter.1), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2 + + reduction { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernel_entry { + constant0 = f32[] constant(0) + arg1 = f16[1024,512]{1,0} parameter(0) + arg1_conv = f32[1024,512]{1,0} convert(arg1) + ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction + })"; + + // Check that two calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, + ColumnReductionWithInputLargerThenReduceInputNotUnrolled) { + const char *const kHloString = R"( + HloModule larger_than_reduce_input_parameter + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1027,513]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1027,513]{1,0} convert(arg.2) + slice2 = f32[1024,512]{1,0} slice(arg2.conv), slice={[2:1026], [1:513]} + add2 = f32[1024,512]{1,0} add(arg1.conv, slice2) + ROOT reduce = f32[512]{0} reduce(add2, constant0), dimensions={0}, + to_apply=reduction22 + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1027,513]{1,0} parameter(1) + ROOT fusion = f32[512]{0} fusion(arg1, arg2), kind=kInput, + calls=fused_computation + })"; + + // Check that one call to llvm.nvvm.atomic is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2_mof + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1024,512]{1,0} convert(arg.2) + reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0}, + to_apply=reduction22 + reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0}, + to_apply=reduction22 + add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv) + ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + tuple(reduce1, reduce2, add) + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + fusion(arg1, arg2), kind=kInput, calls=fused_computation + })"; + + // Check that four calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index f8120a5fa00ce38644cd85c54d5ef65701be1eda..f91a22d482bc8bc046977870a7a4d18ca1acde68 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -43,7 +43,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 6b2d76764a077dc6cfa3f9ddc6e525ab330323be..25bad67bab9375559c431466571c62acd0452b01 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,17 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { namespace gpu { void ThunkSchedule::AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk) { - if (hlo_to_thunk.count(&operand)) { + const absl::flat_hash_map& hlo_to_thunk) { + if (hlo_to_thunk.contains(&operand)) { // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency // list if `operand` is assigned to a different stream. As an optimization, // we skip `operand`'s operands because `operand` depends on them already. @@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule( const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { - std::unordered_map hlo_to_thunk; + absl::flat_hash_map hlo_to_thunk; for (const auto& thunk : *thunks_) { InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } for (HloInstruction* hlo : hlo_total_order) { - if (hlo_to_thunk.count(hlo)) { - thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); + if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { + thunk_total_order_.push_back(*thunk); } } @@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { // redundant dependency edge. Array2D last_dependency(stream_count, stream_count, -1); for (const Thunk* dst : thunk_total_order_) { - if (!depends_on_.count(dst)) { + if (!depends_on_.contains(dst)) { continue; } @@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { const std::list& ThunkSchedule::DependsOn( const Thunk* thunk) const { - if (depends_on_.count(thunk)) { + if (depends_on_.contains(thunk)) { return FindOrDie(depends_on_, thunk); } else { return empty_thunk_list_; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 43b628a1baf0e79a3197f3cfad3547991642eaed..549378debd52417252724a5d8a6f4d24f2ad0369 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -54,7 +56,9 @@ class ThunkSchedule { // Thunks that `thunk` depends on. const std::list& DependsOn(const Thunk* thunk) const; // Whether `thunk` is depended by another thunk. - bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); } + bool Depended(const Thunk* thunk) const { + return depended_by_.contains(thunk); + } // Delegates to StreamAssignment. int StreamCount() const { return stream_assignment_->StreamCount(); } @@ -75,13 +79,13 @@ class ThunkSchedule { // thunk.hlo_instruction(). void AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk); + const absl::flat_hash_map& hlo_to_thunk); std::unique_ptr thunks_; std::vector thunk_total_order_; - std::unordered_map> depends_on_; - std::set depended_by_; + absl::flat_hash_map> depends_on_; + absl::flat_hash_set depended_by_; std::list empty_thunk_list_; std::unique_ptr stream_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h index dd46ff433ba0ad6bfa3999b96845fdaebe148aca..167c038420a64d9fa29746ed3fe349620e08e6ff 100644 --- a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h +++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h @@ -47,6 +47,10 @@ class XfeedQueue { // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. BufferType BlockingGetNextDestination() { + for (const auto& callback : before_get_next_dest_callbacks_) { + callback(); + } + bool became_empty; BufferType current_buffer; { @@ -69,6 +73,10 @@ class XfeedQueue { void RegisterOnEmptyCallback(std::function callback) { on_empty_callbacks_.push_back(std::move(callback)); } + void RegisterBeforeGetNextDestinationCallback( + std::function callback) { + before_get_next_dest_callbacks_.push_back(std::move(callback)); + } private: tensorflow::mutex mu_; @@ -82,6 +90,11 @@ class XfeedQueue { // List of callbacks which will be called when 'enqueued_buffers_' becomes // empty. std::vector> on_empty_callbacks_; + + // List of callbacks which will be called before BlockingGetNextDestination() + // is called. This lets you e.g. call EnqueueDestination() for each call to + // BlockingGetNextDestination(). + std::vector> before_get_next_dest_callbacks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9220865867b770eebfb1ada8f31a5d24693a4b8d..4fca981c6a59cdb91a997e6a887fd26472c1a10a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation( // If the buffer has no users and isn't an entry parameter or output, it // must be a dead value. - if (live_buffers.count(buffer) == 0) { + if (!live_buffers.contains(buffer)) { dead_buffers_to_free.push_back(buffer); } } @@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation( } } // Sort to get a deterministic iteration order. - std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(operand_buffers_to_free, + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { - if (reused_buffers.count(operand_buffer) != 0) { + if (reused_buffers.contains(operand_buffer)) { continue; } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && @@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation( to_free.push_back(buffer); } - std::sort(to_free.begin(), to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); @@ -374,15 +373,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { return true; } return options_.buffers_to_assign != nullptr && - options_.buffers_to_assign->count(buffer) == 0; + !options_.buffers_to_assign->contains(buffer); } // Alloc always calls the underlying heap algorithm. void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "Alloc called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Alloc called on freed buffer: " << *buffer; allocated_buffers_.insert(buffer); @@ -411,9 +410,9 @@ void HeapSimulator::Free(const BufferValue* buffer, buffer = group->canonical; } - CHECK(allocated_buffers_.count(buffer) > 0) + CHECK(allocated_buffers_.contains(buffer)) << "Free called on non-allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Free called on freed buffer: " << *buffer; freed_buffers_.insert(buffer); @@ -433,11 +432,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "ShareBuffer called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "ShareBuffer called on freed buffer: " << *buffer; - CHECK(freed_buffers_.count(shared) == 0) + CHECK(!freed_buffers_.contains(shared)) << "ShareBuffer called on freed shared buffer: " << *shared; const BufferValue* canonical = nullptr; @@ -452,7 +451,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, } else { // The 'shared' buffer doesn't have a group; it must be the canonical. Add // both 'buffer' and 'shared' to a new group. - CHECK(allocated_buffers_.count(shared) > 0) + CHECK(allocated_buffers_.contains(shared)) << "ShareBuffer called on non-allocated shared buffer: " << *shared; auto group = std::make_shared(); canonical = shared; @@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { } // Call ops in the run sorted by decreasing size, breaking ties by buffer id. - std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) { + absl::c_sort(run_, [](const Op& a, const Op& b) { if (a.size != b.size) { return a.size > b.size; } @@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); } - std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), - [](const BufferInterval& x, const BufferInterval& y) { - if (x.size != y.size) { - return x.size > y.size; - } - if (x.end - x.start != y.end - y.start) { - return x.end - x.start > y.end - y.start; - } - return x.buffer->id() < y.buffer->id(); - }); + absl::c_sort(sorted_buffer_intervals, + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); for (auto& buffer_interval : sorted_buffer_intervals) { auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( buffer_interval.start, buffer_interval.end); - std::sort( - chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + absl::c_sort( + chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); // Find the minimum free chunk that can hold this buffer. diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index dbbf43082f2c1d21f5ef42f53804bf0969903a58..3e0631aeb4aa374cb5748650e1c7529e26e10b34 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -158,7 +158,7 @@ class HeapSimulator { void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const BufferValue* buffer, const HloInstruction* instruction, - const BufferValue* shared_with_canonical); + const BufferValue* share_with_canonical); // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 9b50f1ca5b5365463f32106fc005ef2c63f2e37a..263b42a29dbb0dbc0fb6eca7968674ff242f45ed 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -229,6 +229,18 @@ message HloScheduleProto { } message HloInputOutputAliasProto { + enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // An alias setup by the user as must alias. A use setting USER_ALIAS is + // expecting the designed output to be dropped over the given input + // parameter number+index. + USER_ALIAS = 1; + // An alias setup by the compiler as part of its optimizations. + SYSTEM_ALIAS = 2; + } + // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) // and an output (described by a ShapeIndex of the root @@ -249,6 +261,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; + // The kind of alias to be setup. + Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c..e511f1951c5dd07ebb64fa38fd5b7f6a0e87b429 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -117,7 +117,7 @@ class BufferValueMap { for (const auto& pair : buffers_) { buffer_numbers.push_back(pair.first); } - std::sort(buffer_numbers.begin(), buffer_numbers.end()); + absl::c_sort(buffer_numbers); return buffer_numbers; } @@ -176,13 +176,12 @@ class BufferValueMap { const HloValue& value, std::vector* aliased_buffers) { // Get parameter value from an aliased_input object. const auto get_parameter_value = - [this](const std::pair& aliased_input) + [this](const HloInputOutputAliasConfig::Alias& aliased_input) -> const HloValue& { - int64 param_number = aliased_input.first; - const ShapeIndex& param_index = aliased_input.second; return dataflow_.GetUniqueValueAt( - module_->entry_computation()->parameter_instruction(param_number), - param_index); + module_->entry_computation()->parameter_instruction( + aliased_input.parameter_number), + aliased_input.parameter_index); }; // If the value shows up in a root instruction, alias it with parameter @@ -319,7 +318,7 @@ class BufferValueMap { ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. - std::sort(aliased_buffers.begin(), aliased_buffers.end()); + absl::c_sort(aliased_buffers); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); @@ -367,7 +366,7 @@ std::vector HloAliasAnalysis::ComputeBuffersAt( } // Sort and uniquify vector before returning. - std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); + absl::c_sort(buffers, HloBuffer::IdLessThan); buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end()); return buffers; @@ -430,8 +429,7 @@ Status HloAliasAnalysis::Verify() const { for (const auto& pair : value_to_buffer_) { const HloValue* value = pair.first; const HloBuffer& buffer = *pair.second; - TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(), - value) != buffer.values().end()); + TF_RET_CHECK(absl::c_linear_search(buffer.values(), value)); } for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { @@ -457,7 +455,7 @@ string HloAliasAnalysis::ToString() const { for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { ShapeUtil::ForEachSubshape( instruction->shape(), [&out, &instruction, this](const Shape&, const ShapeIndex& index) { @@ -515,7 +513,7 @@ StatusOr> HloAliasAnalysis::Run( auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); std::vector sorted_values(value_set.begin(), value_set.end()); - std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); + absl::c_sort(sorted_values, HloValue::IdLessThan); alias_analysis->buffers_.emplace_back(next_id++, sorted_values); for (const HloValue* value : sorted_values) { alias_analysis->value_to_buffer_[value] = @@ -533,11 +531,11 @@ bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { CHECK(!buffer.values().empty()); - if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + if (buffer.values().front()->shape().IsToken()) { // Tokens have no on-device representation and cannot interfere. for (const HloValue* value : buffer.values()) { // If one of the values is a token, all values must be a token. - DCHECK(ShapeUtil::IsToken(value->shape())); + DCHECK(value->shape().IsToken()); } continue; } @@ -547,16 +545,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // tie-break using value ID. The tie-break is necessary because we need a // strict weak order for std::sort. std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [&ordering](const HloValue* a, const HloValue* b) { - if (ordering.IsDefinedBefore(*a, *b)) { - return true; - } else if (ordering.IsDefinedBefore(*b, *a)) { - return false; - } else { - return a->id() < b->id(); - } - }); + absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) { + if (ordering.IsDefinedBefore(*a, *b)) { + return true; + } else if (ordering.IsDefinedBefore(*b, *a)) { + return false; + } else { + return a->id() < b->id(); + } + }); // Walk through the ordered vector of values. First verify that the values // are totally ordered with respect to 'ordering', then check that no diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 7e6150e94153cd15463725e862ce1b8593f2c991..b6dbf07959c541bceaa8eda5a0101503970ee832 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -238,13 +238,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -279,13 +282,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -365,9 +371,11 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 9c3aa0e64d119c2560f4955d0bcb492519fa52a2..32e48651b30bace4723169935d1f10dd7d7bfec3 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -49,7 +49,7 @@ std::vector HloBuffer::ComputePositions() const { value->positions().end()); } // Remove duplicates and sort positions. - std::sort(positions.begin(), positions.end()); + absl::c_sort(positions); positions.erase(std::unique(positions.begin(), positions.end()), positions.end()); return positions; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 75630307186ba42f711a85454d73722533e59358..40fe91398be33f5681e1389e1b6fadcbd87487bb 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(instruction->user_count() == 0); TF_RET_CHECK(IsRemovable(instruction)) << "Cannot remove instruction: " << instruction->ToString(); - std::unordered_set removed; + absl::flat_hash_set removed; std::queue worklist; worklist.push(instruction); while (!worklist.empty()) { HloInstruction* item = worklist.front(); worklist.pop(); - if (removed.count(item) != 0 || item->user_count() != 0 || + if (removed.contains(item) || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || (item->HasSideEffect() && item != instruction)) { continue; @@ -531,11 +531,10 @@ HloComputation::CreateFromProto( HloInstruction* root = instruction_map.at(proto.root_id()); // Sort the instructions in the proto id's order. - std::sort(instructions.begin(), instructions.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(instructions, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); TF_RETURN_IF_ERROR([&]() -> Status { std::vector parameters_seen(parameter_count); @@ -600,7 +599,7 @@ StatusOr HloComputation::DeepCopyHelper( const std::function< HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, HloComputation* computation)>& copy_leaf) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -617,14 +616,14 @@ StatusOr HloComputation::DeepCopyHelper( } return AddInstruction(HloInstruction::CreateTuple(elements)); } - if (ShapeUtil::IsToken(instruction->shape())) { + if (instruction->shape().IsToken()) { // Tokens have no on-device representation and cannot be copied. Pass // through transparently. return instruction; } // Array shape. - TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + TF_RET_CHECK(instruction->shape().IsArray()); return copy_leaf(instruction, *index, this); } @@ -694,22 +693,36 @@ bool HloComputation::operator==(const HloComputation& other) const { if (this == &other) { return true; } - std::set> visited; - std::function eq = - [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { - // If are visited but not identical, the recursion should have - // been aborted. So, if are visited at this point, they must be - // identical. - if (visited.count(std::make_pair(a, b)) > 0) { - return true; - } - visited.emplace(a, b); - return a->Identical( - *b, eq, [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }); - }; - return eq(root_instruction(), other.root_instruction()); + absl::flat_hash_set> + visited; + std::vector> worklist; + + worklist.push_back({root_instruction(), other.root_instruction()}); + + while (!worklist.empty()) { + auto pair = worklist.back(); + worklist.pop_back(); + + if (visited.contains(pair)) { + continue; + } + visited.emplace(pair); + // TODO(b/123082518): Avoid recursively invoking == becasue it may + // cause a stack overflow with deeply nested subcomputations. + bool identical_ignoring_operands = pair.first->Identical( + *pair.second, + [](const HloInstruction*, const HloInstruction*) { return true; }, + [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }); + if (!identical_ignoring_operands) { + return false; + } + for (size_t i = 0; i < pair.first->operands().size(); ++i) { + worklist.push_back({pair.first->operand(i), pair.second->operand(i)}); + } + } + return true; } Status HloComputation::ReplaceWithNewInstruction( @@ -799,17 +812,16 @@ Status HloComputation::AcceptOrdered( absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { - TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) - << root->ToString(); + TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); } TF_RET_CHECK(order.size() == instruction_count()); - std::unordered_set visited; + absl::flat_hash_set visited; for (const HloInstruction* instruction : order) { VLOG(3) << "Visiting ordered: " << instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) + TF_RET_CHECK(instruction_iterators_.contains(instruction)) << "Instruction " << instruction->name() << " is not in computation " << name(); - TF_RET_CHECK(visited.count(instruction) == 0) + TF_RET_CHECK(!visited.contains(instruction)) << "Instruction " << instruction->name() << " appears more than once in order"; HloInstruction* mutable_instruction = @@ -845,29 +857,31 @@ Status HloComputation::Accept( std::unique_ptr HloComputation::Clone( const string& suffix, HloCloneContext* context) { return CloneWithReplacements( - /*replacements=*/std::unordered_map>(), - context, suffix); + /*replacements=*/absl::flat_hash_map>(), + /*extra_parameters=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, std::pair> r2, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -875,17 +889,19 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r2, std::pair> r3, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); replacements.emplace(std::move(r3)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters, HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { @@ -951,6 +967,12 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; + // First add the extra parameters to 'instructions'. + for (const auto& instr : extra_parameters) { + CHECK_EQ(instr->opcode(), HloOpcode::kParameter) + << "Only parameter instructions are allowed in 'extra_parameters'"; + instructions.emplace_back(instr->Clone()); + } for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a0ccbc583f8c409f29d31756fcc1fa1b4af7dc35..0cb9caddd089011f3e9a4473995847dc966dd402 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -323,11 +322,16 @@ class HloComputation { // that's not already in the computation, it's cloned and added to the new // computation. // + // 'extra_parameters' allows to specify additional parameters that should be + // added to the computation. + // // All relevant instructions are cloned, *including* unique_ptr in the // `replacements` map. std::unique_ptr CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters = {}, HloCloneContext* context = nullptr, const string& suffix = "clone"); // Convenience overloads for CloneWithReplacements. You want to do diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 0361c87428f6e4c031d95492a5bc782ad388e5b5..3b88e9745c27d6e1f2a46e5c83ac2e8bd8d05150 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -226,7 +230,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { : computation_(computation) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - EXPECT_EQ(0, visited_set_.count(hlo_instruction)); + EXPECT_FALSE(visited_set_.contains(hlo_instruction)); visited_set_.insert(hlo_instruction); last_visited_ = hlo_instruction; return Status::OK(); @@ -239,7 +243,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { } HloComputation* computation_; - std::set visited_set_; + absl::flat_hash_set visited_set_; int64 finish_visit_calls_ = 0; HloInstruction* last_visited_ = nullptr; }; @@ -491,6 +495,41 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, CloneWithReplacements) { + auto builder = HloComputation::Builder(TestName()); + Shape r0s64 = ShapeUtil::MakeShape(S64, {}); + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + Shape r0u32 = ShapeUtil::MakeShape(U32, {}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); + auto lt = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); + absl::flat_hash_map> + replacements; + replacements.emplace(param2, + HloInstruction::CreateParameter(2, r0s32, "p.1")); + auto param3 = HloInstruction::CreateParameter(3, r0u32, "p.2"); + std::vector extra_parameters{param3.get()}; + auto clone = computation->CloneWithReplacements(std::move(replacements), + extra_parameters); + ASSERT_EQ(clone->num_parameters(), 4); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(0)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(1)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(2)->shape(), r0s32)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(3)->shape(), r0u32)); +} + TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,5 +645,28 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } +std::unique_ptr MakeAddNComputation(int n) { + auto builder = HloComputation::Builder("add_n"); + auto result = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "x_value")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + for (int i = 0; i < n; ++i) { + result = builder.AddInstruction(HloInstruction::CreateBinary( + one->shape(), HloOpcode::kAdd, result, one)); + } + return builder.Build(); +} + +TEST_F(HloComputationTest, DeepEquality) { + auto computation_a = MakeAddNComputation(200000); + auto computation_b = MakeAddNComputation(200000); + EXPECT_TRUE(*computation_a == *computation_b); + + auto computation_c = MakeAddNComputation(199999); + EXPECT_FALSE(*computation_a == *computation_c); + EXPECT_FALSE(*computation_c == *computation_b); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 5e37883d3d8d5067bab873ac6b5f732e7360c5fa..e7ed858e8c5af83d08863d64a0aba162c75ed5cb 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,6 +35,34 @@ limitations under the License. namespace xla { +// Checks whether instr is or transitively contains an instruction that we +// shouldn't fold. +// +// Specifically, we don't fold kRng or kAfterAll instructions: +// +// - kRng is already marked as side-effecting and so is skipped elsewhere, but +// we check for it here. Even kRng weren't side-effecting and took an +// explicit seed, we *still* wouldn't want to constant-fold it, because the +// evaluator's handling of rng is not guaranteed to be identical to any +// particular backend's rng. +// +// - kAfterAll needs to be skipped because a kAfterAll op with no args can +// currently materialize a token "out of thin air". TODO(b/110532604): +// Remove this check once AfterAll requires at least one operand, in which +// case constant folding will be impossible. +static bool IsOrContainsIllegalInstr(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kAfterAll || + instr->opcode() == HloOpcode::kRng) { + return true; + } + for (const HloComputation* c : instr->called_computations()) { + if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) { + return true; + } + } + return false; +} + StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may @@ -52,25 +80,24 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Tuple, AfterAll operation. - // Tuple constants are not directly supported by any backends, hence - // folding Tuple is not useful and would in fact be expanded back into - // kTuple by Algebraic Simplifier. - // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one - // operand in which case constant folding will be impossible and this - // special case is not necessary. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kAfterAll) { - continue; - } // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; } + // Don't fold Constant, Parameter, and Tuple instructions. Tuple + // constants are not directly supported by any backends, hence folding + // Tuple is not useful and would in fact be expanded back into kTuple by + // Algebraic Simplifier. + // + // (We do allow folding subcomputations that contain these instructions.) + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kTuple) { + continue; + } + // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. @@ -79,12 +106,23 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Check for instructions that we can't fold even if they appear inside of + // a subcomputation (e.g. a kCall). + if (IsOrContainsIllegalInstr(instruction)) { + continue; + } + + // Don't constant-fold side-effecting instructions or instructions which + // contain side-effecting instructions. + if (instruction->HasSideEffect()) { + continue; + } + // Don't constant fold unless it's a net positive or the output is small. - if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { int64 elements_in_removed_operands = 0; for (HloInstruction* operand : instruction->operands()) { - if (operand->user_count() == 1 && - ShapeUtil::IsArray(operand->shape())) { + if (operand->user_count() == 1 && operand->shape().IsArray()) { elements_in_removed_operands += ShapeUtil::ElementsIn(operand->shape()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 92b748d813c3efef83ef0155f1d5d3c637ce2c57..4bdc980c9ac4fb79cde0242f407aea7057474b27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -268,5 +268,51 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { GmockMatch(m::Pad(m::Constant(), m::Constant()))); } +TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) { + const char* const kModuleStr = R"( + HloModule test + + Fn { + tok = token[] after-all() + ROOT root = f32[10] iota(), iota_dimension=0 + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(HloConstantFoldingTest, + DontFoldSubcomputationTransitivelyContainingRng) { + const char* const kModuleStr = R"( + HloModule test + + InnerFn { + c0 = f32[] constant(0) + c1 = f32[] constant(1) + ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform + } + + Fn { + ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index cb431aed47f0a751a697305986a8a0c194ac966c..76fd402b2c25c8dbed7902a458cd3af44f89cbd1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -237,24 +237,17 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); + const Shape& dot_shape = dot->shape(); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = - lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); - // First divide by reduction width before multiplying by rhs elements to avoid - // overflow. - int64 fma_count; - if (reduction_width == 0) { - fma_count = 0; - } else { - fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) * - ShapeUtil::ElementsIn(rhs_shape); + int64 reduction_width = 1; + for (auto dim : dnums.lhs_contracting_dimensions()) { + reduction_width *= lhs_shape.dimensions(dim); } - - // We count an FMA operation as 2 floating point operations. - current_properties_[kFlopsKey] = kFmaFlops * fma_count; + // Each output elment requires reduction_widht FMA operations. + current_properties_[kFlopsKey] = + kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; return Status::OK(); } @@ -292,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { // does not need to be multiplied by the number of input tensors - that's // already "priced in" by the sub-computation doing more work. auto arg = reduce->operand(0); - auto output_shape = ShapeUtil::IsArray(reduce->shape()) + auto output_shape = reduce->shape().IsArray() ? reduce->shape() : reduce->shape().tuple_shapes(0); int64 reduction_count = @@ -539,7 +532,7 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { auto real_shape = - ShapeUtil::IsTuple(fft->operand(0)->shape()) + fft->operand(0)->shape().IsTuple() ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) : fft->operand(0)->shape(); constexpr int kFmaPerComplexMul = 4; @@ -561,7 +554,7 @@ Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { double flops = 0.0; ShapeUtil::ForEachSubshape(crs->shape(), [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { flops += ShapeUtil::ElementsIn(subshape); } }); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7..4d42770ba784ba15fae9518b40a75d8a2f038e66 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -157,6 +158,87 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30)); } +TEST_F(HloCostAnalysisTest, DotGeneral) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral2) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral3) { + XlaBuilder builder("matrix_multiply"); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + DotDimensionNumbers dnums; + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (1500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30)); +} + TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); @@ -529,7 +611,8 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); - DynamicSlice(x, ConstantR1(&builder, {1}), {1}); + DynamicSlice(x, absl::Span({ConstantR0(&builder, 1)}), + {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -545,7 +628,7 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { XlaBuilder builder("dynamic-update-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), - ConstantR1(&builder, {1})); + absl::Span({ConstantR0(&builder, 1)})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index e41aeab19e49ddd4f2363746f0ff8ba1740139b3..d56f673455f9129b72e9d85eaf8cbf03cfee4302 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -105,12 +106,26 @@ StatusOr MakeDynamicSliceHlo( absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. + auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), slice_sizes)); + operand->shape(), scalar_start_indices_shapes, slice_sizes)); return computation->AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, operand, start_indices, slice_sizes)); + dynamic_slice_shape, operand, scalar_start_indices, slice_sizes)); } StatusOr MakeDynamicUpdateSliceHlo( @@ -119,17 +134,31 @@ StatusOr MakeDynamicUpdateSliceHlo( HloComputation* computation = operand->parent(); CHECK_EQ(computation, update->parent()); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. + auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), scalar_start_indices_shapes)); return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dynamic_update_slice_shape, operand, update, start_indices)); + dynamic_update_slice_shape, operand, update, scalar_start_indices)); } -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds) { +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -189,8 +218,7 @@ StatusOr MakeMapHlo(absl::Span operands, for (const HloInstruction* operand : operands) { CHECK_EQ(computation, operand->parent()); operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); @@ -207,7 +235,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, HloOpcode binary_opcode, HloModule* module) { DCHECK_NE(nullptr, module); - std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::vector all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); @@ -366,9 +394,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, return MakePadHlo(operand, zero, padding_config); } -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions) { +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 8e5ddbbd503a501bd493aec43a2ccd4db883ef0c..1c3174e9c89c16cb11589e7c0235bdf13eae6b85 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -82,9 +82,9 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds); +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -198,9 +198,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // Broadcasts a zero value of type `element_type` into a tensor with element // type `element_type` and dimension bounds `broadcast_dimensions`. The // broadcast instruction is emitted into `computation`. -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions); +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index aaa9ec60eb3c4e0159ed40b37d772e0973d306ec..6025e6a77941369f75ebaa98bdf0979669b3a03c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -56,9 +56,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } @@ -77,10 +77,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, - {LiteralUtil::CreateR3( - {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); @@ -101,8 +100,7 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } @@ -121,8 +119,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } @@ -141,7 +138,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } @@ -160,8 +157,8 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } @@ -180,9 +177,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } @@ -194,15 +191,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { /*output_shape_dims=*/{2, 2}, ¶m, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), S32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), S32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } @@ -214,15 +210,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { /*output_shape_dims=*/{2, 2}, ¶m, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), F32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), F32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR0(0.0f)})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3ed3d3c11c71dc534f193ba3ffb556b0eb0c80e4..3144a84805454488f417391f40ed6b9e9facc752 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( return false; } } - if (!visited.count(user)) { + if (!visited.contains(user)) { stack.push_back(user); } } @@ -190,7 +190,7 @@ string HloDataflowAnalysis::ToString() const { for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, @@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.push_back(value->id()); } } - std::sort(input_value_ids.begin(), input_value_ids.end()); + absl::c_sort(input_value_ids); input_value_ids.erase( std::unique(input_value_ids.begin(), input_value_ids.end()), input_value_ids.end()); @@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi( if (current_value_defined_here) { VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); - auto it = std::find(input_value_ids.begin(), input_value_ids.end(), - current_value->id()); + auto it = absl::c_find(input_value_ids, current_value->id()); if (it != input_value_ids.end()) { input_value_ids.erase(it); } @@ -921,8 +920,7 @@ StatusOr> HloDataflowAnalysis::Run( for (auto& pair : dataflow_analysis->values_) { dataflow_analysis->values_vector_.push_back(&pair.second); } - std::sort(dataflow_analysis->values_vector_.begin(), - dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); + absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan); TF_DCHECK_OK(dataflow_analysis->Verify()); @@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const { for (const HloValue* value : values()) { for (const HloPosition& position : value->positions()) { const HloValueSet& value_set = GetValueSet(position); - TF_RET_CHECK(std::find(value_set.values().begin(), - value_set.values().end(), - value) != value_set.values().end()) + TF_RET_CHECK(absl::c_linear_search(value_set.values(), value)) << "Value set at position " << position << " does not contain value " << value->ToShortString(); } @@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const { const HloValueSet& value_set = pair.second; const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { - TF_RET_CHECK(std::find(value->positions().begin(), - value->positions().end(), - position) != value->positions().end()) + TF_RET_CHECK(absl::c_linear_search(value->positions(), position)) << "Value set at position " << position << " unexpectedly contains value " << value->ToShortString(); } @@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } @@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // *) The root instruction of the called computation is element-wise on // 'operand'. const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + absl::c_find_if(uses, [user](const HloUse& use) { return use.instruction == user; }) != uses.end(); auto* callee_root = user->to_apply()->root_instruction(); const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); + absl::c_find_if(uses, [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 94de7c55dd2402e55ec344b79c24af2d8283fe73..4a7c4963b7b399e625da907b3810c42df7ee2bd3 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -73,8 +73,8 @@ class HloDataflowAnalysisTest : public HloTestBase, bool InstructionsMayInterfere(const HloOrdering& ordering, const HloInstruction* a, const HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), analysis_->GetValueDefinedAt(b), *analysis_); } @@ -1901,9 +1901,9 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] { EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); } -INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, - HloDataflowAnalysisTest, - ::testing::Values(false, true)); +INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation, + HloDataflowAnalysisTest, + ::testing::Values(false, true)); class HloDataflowAnalysisTestBase : public HloTestBase { protected: @@ -1970,12 +1970,13 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2012,12 +2013,13 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2150,17 +2152,17 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); - auto index = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); - auto ds = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, param, {zero, zero}, {1, 2, 2})); - auto dus = builder.AddInstruction( - HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, param, ds, {zero, zero})); BuildModule(builder.Build()); auto fusion = computation_->CreateFusionInstruction( - {dus, ds, index}, HloInstruction::FusionKind::kLoop); + {dus, ds, zero}, HloInstruction::FusionKind::kLoop); RunAnalysis(); EXPECT_TRUE( @@ -2219,12 +2221,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2259,12 +2262,13 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape_bf16, convert1, update, starts)); + data_shape_bf16, convert1, update, + std::initializer_list({starts}))); auto convert2 = builder.AddInstruction( HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); @@ -2290,10 +2294,13 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto start0 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "start0")); + auto start1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, starts_shape, "start1")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {start0, start1})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2304,7 +2311,9 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { EXPECT_FALSE( dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); + dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {})); } TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 7d35e251ca21951036336ff1a1eb4aabc87bc5ca..a5a11f09cf4f857b992e5ede3a9dbc5a937ce722 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -65,7 +66,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Now DCE HloComputations. First, collect the computations that are // referenced by some remaining instruction. - std::unordered_set live_computations; + absl::flat_hash_set live_computations; if (HloComputation* entry_computation = module->entry_computation()) { live_computations.insert(entry_computation); } @@ -79,7 +80,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Remove dead computations. for (auto* computation : module->MakeComputationPostOrder()) { - if (live_computations.count(computation) == 0) { + if (!live_computations.contains(computation)) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 1fa4259a3e42286cbc911907eea563e6ca6f8611..b5d72b386f89568cc3066b2e497be98428d1ed0c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index c6d02f9f67bb599e496d20fc2acf2e627ed54438..7cdb7f6bdf26241cda4fabbb5ccaf6e6f7de39ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions( } } // sort instructions according to instructions_order - std::sort(instructions.begin(), instructions.end(), - [&instructions_order](HloInstruction* a, HloInstruction* b) { - return instructions_order.at(a) < instructions_order.at(b); - }); + absl::c_sort(instructions, + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); + }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index a40b6d888c548bf0909f413c092fc32cfc0a4892..9b0f2b2a0f4dd5d1d1191e9ab0637cc3034b50da 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -68,7 +68,7 @@ Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, std::vector new_tuple_subshapes; for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); - CHECK(!ShapeUtil::IsTuple(subshape)); + CHECK(!subshape.IsTuple()); if (subshape.element_type() == from_type) { subshape = ShapeUtil::ChangeElementType(subshape, to_type); } @@ -92,7 +92,7 @@ HloInstruction* ConvertTupleElements(HloInstruction* hlo, HloInstruction* element = computation->AddInstruction( HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); - CHECK(!ShapeUtil::IsTuple(ele_shape)); + CHECK(!ele_shape.IsTuple()); if (ele_shape.element_type() != to_ele_shape.element_type()) { element = computation->AddInstruction( HloInstruction::CreateConvert(to_ele_shape, element)); @@ -190,7 +190,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); - } else if (ShapeUtil::IsTuple(hlo->shape())) { + } else if (hlo->shape().IsTuple()) { Shape old_shape = hlo->shape(); Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index a3b56a44a0b02923585c1dcb69571479236188a3..5b633784e2f306290ca6c096f67c657be1f188c8 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -28,15 +28,7 @@ using ::testing::Eq; using ::testing::Not; using ::testing::ResultOf; -class HloElementTypeConverterTest : public HloTestBase { - public: - std::unique_ptr CreateModuleFromHloString( - const string& hlo_string) { - return HloRunner::CreateModuleFromString(hlo_string, - GetDebugOptionsForTest()) - .ValueOrDie(); - } -}; +using HloElementTypeConverterTest = HloTestBase; TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { const string& hlo_string = R"( @@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { custom_call_target="foo" } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -63,7 +55,7 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { outfeed = token[] outfeed(infeed.data, token0) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { const string& hlo_string = R"( HloModule NestedTuples ENTRY NestedTuples.v5 { - constant.4 = bf16[] constant(42) constant.2 = f32[2]{0} constant({1, 2}) - constant.3 = bf16[] constant(42) - add = bf16[] add(constant.2, constant.3) - tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.3 = bf16[2]{0} constant({42, 42}) + add = bf16[2]{0} add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add) constant.5 = bf16[2]{0} constant({22, 44}) - ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -135,7 +126,7 @@ ENTRY main { ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -161,7 +152,7 @@ ENTRY main { ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 934c082bb9f003b1d2d80835f09a8f4109c7e7fd..ecde37be58a381be7968b04de7bbe1d85d7ddb25 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include #include "absl/algorithm/container.h" @@ -29,7 +29,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -136,8 +135,44 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + } // namespace +// Note that unsupported types by the typed visitor does not necessarily imply +// the non-typed HloEvaluator (parent evaluator) would not support them either +// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent +// type-agnostic evaluator will be able to accept Tuple primitive type, whereas +// HloEvaluatorTypedVisitor cannot. HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = @@ -145,22 +180,14 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) typed_visitors_[U8] = absl::make_unique>(this); typed_visitors_[U16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); + absl::make_unique>(this); typed_visitors_[U32] = absl::make_unique>(this); typed_visitors_[U64] = absl::make_unique>(this); typed_visitors_[S8] = absl::make_unique>(this); typed_visitors_[S16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); + absl::make_unique>(this); typed_visitors_[S32] = absl::make_unique>(this); typed_visitors_[S64] = @@ -173,6 +200,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) absl::make_unique>(this); typed_visitors_[C64] = absl::make_unique>(this); + typed_visitors_[C128] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all @@ -198,65 +227,30 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) }); } -template -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); - } - - TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - - return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); - } - return Evaluate(module, arg_literal_ptrs); -} - -template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span arg_literals) { + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); + if (arg_literals.size() != computation.num_parameters()) { + return InvalidArgument( + "Expected %d argument%s, but got %d.", computation.num_parameters(), + computation.num_parameters() == 1 ? "" : "s", arg_literals.size()); } - - TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); + for (int64 i = 0; i < arg_literals.size(); ++i) { + const auto& computation_shape = + computation.parameter_instruction(i)->shape(); + const auto& arg_shape = arg_literals[i]->shape(); + if (!ShapeUtil::Equal(computation_shape, arg_shape)) { + return InvalidArgument( + "Shape mismatch at parameter %d. Computation expected %s, but arg " + "was %s.", + i, ShapeUtil::HumanStringWithLayout(computation_shape), + ShapeUtil::HumanString(arg_shape)); + } } - return Evaluate(computation, arg_literal_ptrs); -} - -template -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); arg_literals_.clear(); @@ -264,33 +258,20 @@ StatusOr HloEvaluator::Evaluate( arg_literals_.push_back(&*literal_ptr); } - // Evaluate operands of Parameter type against the input literals which - // caches the evaluated literal results. - for (const auto operand : instruction->operands()) { - if (operand->opcode() == HloOpcode::kParameter) { - const Literal* input_literal = arg_literals_[operand->parameter_number()]; - VLOG(2) << "Parameter operand evaluated to: " - << input_literal->ToString(); - TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - - evaluated_[operand] = input_literal->Clone(); - } + // Re-seed RNG, either from the configuration's seed or a monotonic + // per-evaluator seed (which prevents two evaluators from returning the same + // random sequence). + if (computation.parent()->config().seed()) { + seed_ = computation.parent()->config().seed(); + } else { + // Start global_seed at a (true) random value. + static std::atomic global_seed{std::random_device()()}; + seed_ = global_seed.fetch_add(1); } + engine_.seed(seed_); - TF_RETURN_IF_ERROR(Preprocess(instruction)); - TF_RETURN_IF_ERROR(instruction->Visit(this)); - TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal : arg_literals) { - arg_literal_ptrs.push_back(&literal); - } - return Evaluate(instruction, arg_literal_ptrs); + TF_RETURN_IF_ERROR(computation.Accept(this)); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); } StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { @@ -408,16 +389,45 @@ Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status HloEvaluator::HandleGetDimensionSize( + HloInstruction* get_dimension_size) { + HloInstruction* operand = get_dimension_size->mutable_operand(0); + int64 dim = get_dimension_size->dimension(); + if (dynamic_dimension_inference_ == nullptr) { + return InvalidArgument( + "Evaluator cannot evaluate get_dimension_size without " + "set_dynamic_dimension_inference."); + } + HloInstruction* dynamic_size = + dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + evaluated_[get_dimension_size] = + GetEvaluatedLiteralFor(dynamic_size).Clone(); + return Status::OK(); + } + + const Shape& shape = get_dimension_size->operand(0)->shape(); + Literal output(ShapeUtil::MakeShape(U32, {})); + output.PopulateWithValue( + static_cast(shape.dimensions(get_dimension_size->dimension()))); + evaluated_[get_dimension_size] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + // Nothing to do other than sanity checks. Parameters' values are stored in + // arg_literals_. CHECK_LT(parameter->parameter_number(), arg_literals_.size()); + +#ifndef NDEBUG const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); +#endif - evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -442,8 +452,8 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); - CHECK(ShapeUtil::IsArray(reference_shape)); - const int64 rank = ShapeUtil::Rank(reference_shape); + CHECK(reference_shape.IsArray()); + const int64 rank = reference_shape.rank(); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); CHECK_LT(concat_dim, rank); @@ -453,7 +463,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)); + CHECK(operand_shape.IsArray()); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += @@ -530,6 +540,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); break; } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex128 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } case F16: { auto result_or = ElementWiseUnaryOpImpl( real, [](Eigen::half elem_operand) { return elem_operand; }, @@ -560,11 +577,61 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { } Status HloEvaluator::HandleImag(HloInstruction* imag) { - auto result_or = ElementWiseUnaryOpImpl( - imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, - GetEvaluatedLiteralFor(imag->operand(0))); + auto operand = imag->operand(0); + switch (operand->shape().element_type()) { + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex128 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } - TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + +Status HloEvaluator::HandleComplex(HloInstruction* complex) { + const Literal& real = GetEvaluatedLiteralFor(complex->operand(0)); + const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1)); + TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape())); + + Literal result(complex->shape()); + switch (complex->shape().element_type()) { + case C64: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + case C128: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + default: + LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: " + << PrimitiveType_Name(complex->shape().element_type()); + } + + evaluated_[complex] = std::move(result); return Status::OK(); } @@ -601,8 +668,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case U16: - return Unimplemented("unhandled primitive type: U16."); + case U16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case U32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -618,8 +688,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case S16: - return Unimplemented("unhandled primitive type: S16."); + case S16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case S32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -655,6 +728,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C128: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); @@ -1036,11 +1114,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand.shape())) + TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank()) << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand.shape()); + << " and rank of operand_to_broadcast is: " << operand.shape().rank(); // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { @@ -1113,9 +1189,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - Literal result = - embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); evaluated_[call] = std::move(result); return Status::OK(); @@ -1131,7 +1208,9 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { - LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + if (!LayoutUtil::HasLayout(instruction->shape())) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } } auto readded_computation = empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); @@ -1145,9 +1224,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result = - embedded_evaluator - .Evaluate(*readded_computation, arg_literals) + embedded_evaluator.Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); evaluated_[fusion] = std::move(result); @@ -1165,16 +1245,16 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result; if (pred.Get({})) { - result = embedded_evaluator - .Evaluate(*true_computation, - {&true_computation_arg}) - .ConsumeValueOrDie(); + result = + embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) + .ConsumeValueOrDie(); } else { result = embedded_evaluator - .Evaluate(*false_computation, - {&false_computation_arg}) + .Evaluate(*false_computation, {&false_computation_arg}) .ConsumeValueOrDie(); } @@ -1221,18 +1301,21 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); + cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_); HloEvaluator loop_body_evaluator(max_loop_iterations_); + loop_body_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, - cond_evaluator.Evaluate(*cond_comp, {&lcv})); + cond_evaluator.Evaluate(*cond_comp, {&lcv})); keep_going = cond_val.GetFirstElement(); if (keep_going) { - TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {&lcv})); + TF_ASSIGN_OR_RETURN(auto body_val, + loop_body_evaluator.Evaluate(*body_comp, {&lcv})); VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); @@ -1243,173 +1326,172 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } -// Key-value sort is a special snowflake: it's templated on two different -// element types, one for the keys, and one for the values. Jump through some -// hoops to make this work. namespace { -template -StatusOr EvaluateSortInternal(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - auto rank = ShapeUtil::Rank(keys_literal.shape()); - TF_RET_CHECK( - ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) - << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; - // We need to sort an array of keys and an array of values, where the - // sorted order of the values is determined by the keys. The simplest(?) - // way to do this is to go to an array-of-pairs representation, sort the - // array using the keys, and then go back to pair-of-arrays. - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - - if (rank == 0) { - // Nothing to sort. - return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); - } - - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); - std::vector zero_base(rank, 0); - std::vector increment(rank, 1); - int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); - increment[sort_dim] = sort_dim_elements; - // Iterate through each dimension except 'sort_dim'. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys_literal.shape(), zero_base, - AsInt64Slice(keys_literal.shape().dimensions()), increment, - [&](absl::Span indices) -> StatusOr { - // Extract a slice from the keys and values literals that correspond to - // exactly the row in dimension 'sort_dim'. - std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); - limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto keys_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& keys_data = keys_to_sort.data(); - TF_ASSIGN_OR_RETURN(auto values_to_sort, - values_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& values_data = values_to_sort.data(); - using kv_pair = std::pair; - std::vector key_value_vector; - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back( - std::make_pair(keys_data[i], values_data[i])); - } - std::stable_sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - // We use a InlinedVector here because we need to convert it to an - // absl::Span later, and this would not work with std::vector. - absl::InlinedVector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - Literal sorted_keys(ShapeUtil::MakeShape( - keys_literal.shape().element_type(), {sort_dim_elements})); - sorted_keys.PopulateR1(absl::Span(result_keys)); - Literal sorted_values(ShapeUtil::MakeShape( - values_literal.shape().element_type(), {sort_dim_elements})); - sorted_values.PopulateR1(absl::Span(result_values)); - std::vector slice_dimensions(rank, 1); - slice_dimensions[sort_dim] = sort_dim_elements; - std::vector start_indices(rank, 0); - TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, - sorted_keys.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys_reshaped, start_indices, indices, slice_dimensions)); - TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, - sorted_values.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values_reshaped, start_indices, indices, slice_dimensions)); - return true; - })); - - Literal result_tuple; - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); - VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); - return std::move(result_tuple); -} - -template -StatusOr EvaluateSortCurried(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (values_literal.shape().element_type()) { - case PRED: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case F32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case U32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case S32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case BF16: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} - -StatusOr EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (sort->operand(0)->shape().element_type()) { - case F32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case U32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case S32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case BF16: - return EvaluateSortCurried(sort, keys_literal, values_literal); +StatusOr ExtractFromIndexPositions(const Literal& from, + absl::Span indices) { + PrimitiveType type = from.shape().element_type(); + switch (type) { + case PRED: { + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. + absl::InlinedVector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); + } + case F32: { + std::vector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); + } + case U32: { + std::vector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); + } + case S32: { + std::vector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); + } + case BF16: { + std::vector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); + } + return LiteralUtil::CreateR1(values); + } default: - return InvalidArgument("Unsupported type for Sort"); + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(type)); } } } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { + if (!sort->shape().IsTuple()) { return DefaultAction(sort); } else { - // This is a really stupid work-around for the fact it's hard to support a - // multi-value sort directly, due to the fact we need to template the - // evaluation function on all of the value types. - std::vector sort_results_backing; - for (int64 i = 0; i < sort->operand_count(); ++i) { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(i))); - if (!result.ok()) { - return result.status(); + TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; + for (int64 i = 1; i < sort->operand_count(); ++i) { + TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(i)->shape())) + << "All Sort operands must have the same dimensions"; + } + + if (VLOG_IS_ON(3)) { + for (int64 i = 0; i < sort->operand_count(); ++i) { + VLOG(3) << "HandleSort operand " << i << " literal: " + << GetEvaluatedLiteralFor(sort->operand(i)).ToString(); } - sort_results_backing.push_back( - std::move(result.ValueOrDie().DecomposeTuple()[1])); } - std::vector sort_results; - absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + Shape key_shape = sort->operand(0)->shape(); + auto rank = key_shape.rank(); + PrimitiveType keys_type = key_shape.element_type(); + if (keys_type != F32 && keys_type != U32 && keys_type != S32 && + keys_type != BF16) { + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(keys_type)); + } + std::vector result_literals; + result_literals.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + result_literals.emplace_back(sort->operand(i)->shape()); + } + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = key_shape.dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment, + [&](absl::Span indices) -> StatusOr { + // Extract a slice from each operand literal that corresponds to + // exactly the row in dimension 'sort_dim'. + std::vector limit_indices(indices.begin(), indices.end()); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + std::vector literals_to_sort; + literals_to_sort.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(auto literal_to_sort, + GetEvaluatedLiteralFor(sort->operand(i)) + .Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + literals_to_sort.push_back(std::move(literal_to_sort)); + } + std::vector indices_to_sort(sort_dim_elements); + std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); + std::stable_sort( + indices_to_sort.begin(), indices_to_sort.end(), + [keys_type, &literals_to_sort](int64 a, int64 b) { + switch (keys_type) { + case F32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case U32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case S32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case BF16: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + default: + // We should never reach here, because we checked earlier + // that 'key_type' is one of the cases above. + LOG(FATAL) << "Invalid key type in Sort: %s", + PrimitiveType_Name(keys_type); + return false; + } + }); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector start_indices(rank, 0); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(Literal sorted_literal, + ExtractFromIndexPositions(literals_to_sort[i], + indices_to_sort)); + TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped, + sorted_literal.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom( + sorted_literal_reshaped, start_indices, indices, + slice_dimensions)); + } + return true; + })); + + std::vector literal_ptrs; + absl::c_transform(result_literals, std::back_inserter(literal_ptrs), [](const Literal& literal) { return &literal; }); - evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); + + Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); + + evaluated_[sort] = std::move(result_tuple); return Status::OK(); } } Status HloEvaluator::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsTuple(reduce->shape())) { + if (!reduce->shape().IsTuple()) { return DefaultAction(reduce); } else { auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); @@ -1424,6 +1506,27 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { } } +Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { + if (!custom_call_handler_) { + // No handler is registered; this means custom-calls are not allowed. + return DefaultAction(custom_call); + } + + // Evaluate input operands so the handler has access to the operand data. + std::vector operands; + operands.reserve(custom_call->operand_count()); + for (const HloInstruction* operand : custom_call->operands()) { + operands.push_back(&GetEvaluatedLiteralFor(operand)); + } + + // Synchronously issue the handler to populate the instruction output literal. + TF_ASSIGN_OR_RETURN( + auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands))); + + evaluated_[custom_call] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return ShapeUtil::ValidateShape(hlo->shape()); @@ -1441,18 +1544,6 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } -// Explicit instantiation of templatized Evaluate* methods. -// -template StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals); - -template StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, - absl::Span arg_literals); - -template StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals); - namespace { template std::unique_ptr> MatmulArray2DImpl( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index d363a51c63de6fd4246c4970f580b68f4a627df8..72ea40bcd797def3bc0765986881792b8752d9e1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -16,13 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#include #include #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -43,16 +46,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // specified. explicit HloEvaluator(int64 max_loop_iterations = -1); - // Evaluates an HLO module and an array of pointers to literals. - // Returns the evaluated result as a literal if successful. + // Evaluates an HLO module and an array of pointers to literals. Returns the + // evaluated result as a literal if successful. + // // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } + template StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals); + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -70,29 +81,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); + template StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); - - // Evaluates a single HLO instruction and an array of pointers to literals. - // Return the evaluated result as literal if successful. - // Precondition: - // 1. argument literals correspond to the input instruction's parameters in - // their post-ordering. - // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template - StatusOr Evaluate(HloInstruction* instruction, - absl::Span arg_literals); - - // Evaluates a single HLO instruction with constant operands. - // Returns the evaluated result as literal if successful. - // Precondition: - // 1. all operands of the input instruction are constants. - // 2. the instruction is not a Parameter operation. + absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& l : arg_literals) { + arg_literal_ptrs.push_back(&l); + } + return Evaluate(computation, arg_literal_ptrs); + } + + // Gets the value of running a single HLO instruction. + // + // All of the operands to this instruction must be constants. StatusOr Evaluate(HloInstruction* instruction); // Same as Evaluate, except returning false on error and accepts an output @@ -120,9 +126,31 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + void set_dynamic_dimension_inference( + DynamicDimensionInference* dynamic_dimension_inference) { + dynamic_dimension_inference_ = dynamic_dimension_inference; + } + // Enable the fast path for certain operations like dot or convolution. void set_use_fast_path(bool value) { use_fast_path_ = value; } + // Handles evaluation of a custom-call op. + // Operand literals are provided in |operands| and implementations must + // populate |output| before returning. + using CustomCallHandler = std::function( + HloInstruction* custom_call, absl::Span operands)>; + + // Sets a handler that is called during evaluation for custom-call ops. + // If no handler is defined the default error behavior will occur. The handler + // will be provided evaluated literals for all operands and is expected to + // return an output literal of the appropriate shape. + void set_custom_call_handler( + std::function(HloInstruction* custom_call, + absl::Span operands)> + handler) { + custom_call_handler_ = std::move(handler); + } + // Returns the result of a matrix multiply `lhs x rhs`. static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); @@ -158,6 +186,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -204,16 +234,51 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + + // Unsupported HLOs, note some of them (such as BatchNorm*) are typically + // expanded in a semantic-preserving way into other HLOs by adding exanpsion + // HLO pass to the HLO optimization pass during compilation, which can then be + // handled by the evaluator. + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { + return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override { + return Unimplemented( + "BatchNormInference HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { + return Unimplemented( + "BatchNormTraining HLO is unsupported by the evaluator."); + }; + Status HandleInfeed(HloInstruction* infeed) override { + return Unimplemented("Infeed HLO is unsupported by the evaluator."); + }; + Status HandleOutfeed(HloInstruction* outfeed) override { + return Unimplemented("Outfeed HLO is unsupported by the evaluator."); + }; + // Returns the already-evaluated literal result for the instruction. + // // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. + // + // Similarly, a Parameter instruction is considered evaluated and its literal + // is looked up in arg_literals. + // // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { if (hlo->IsConstant()) { return hlo->literal(); } + if (hlo->opcode() == HloOpcode::kParameter) { + return *arg_literals_.at(hlo->parameter_number()); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); @@ -221,12 +286,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Tracks the HLO instruction and its evaluated literal result. + // + // Parameters and constants aren't stored here, see implementation of + // GetEvaluatedLiteralFor. + // // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // // Must be cleared for each evaluation. - // Storing Literal in place require the container to have pointer stability so - // we cannot use flat_hash_map any more. + // + // Storing Literal in place requires the container to have pointer stability + // so we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; // Use fast path that uses eigen in the evaluator. @@ -262,6 +333,20 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Max loop iterations to execute with no maximum if negative. int64 max_loop_iterations_; + // Module-level seed handle. + uint64 seed_; + // RNG engine. + std::minstd_rand0 engine_; + + // DynamicDimensionInference is used to evaluate GetDimensionSize, which + // returns the dynamic dimension size of its operand. + DynamicDimensionInference* dynamic_dimension_inference_; + + // Optional handler for custom_call ops. + std::function(HloInstruction* custom_call, + absl::Span operands)> + custom_call_handler_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 8fa493a8732662d5357a68937bfad7ac2b3b8c5d..9bc71c9d6c5e3ed5a3de2d6320762bde6005d3c0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -51,20 +51,18 @@ namespace { static std::array use_bf16_params{true, false}; -class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloTestBase { - protected: - HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { - evaluator_ = absl::make_unique(); - } +// Test fixture for the HloEvaluator. +// +// In bf16 mode, all f32 shapes are converted to bf16 before running. +class HloEvaluatorTest : public HloTestBase { + public: + HloEvaluatorTest() : use_bfloat16_(false) {} Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(m_.get()).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } @@ -74,16 +72,12 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(module).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + return evaluator_.Evaluate(*module->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, float aabs = 0) { HloComputation::Builder b(TestName()); @@ -117,16 +111,27 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } - bool use_bfloat16_; + protected: + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + HloEvaluator evaluator_; + + const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); }; -#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ - TEST_P(test_case_name, test_name) +// Lets you write TEST_Ps that run twice, once with and once without bf16. +class HloEvaluatorBf16Test : public ::testing::WithParamInterface, + public HloEvaluatorTest { + protected: + HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {} +}; + +INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test, + ::testing::ValuesIn(use_bf16_params)); // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_P(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorBf16Test, DoesClamp) { auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -147,7 +152,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); @@ -170,7 +175,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_P(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorBf16Test, DoesSelect) { auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -195,7 +200,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_P(HloEvaluatorTest, DoesAdd) { +TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); @@ -204,7 +209,7 @@ TEST_P(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. -TEST_P(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorBf16Test, DoesAnd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); @@ -213,7 +218,7 @@ TEST_P(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesOr) { +TEST_F(HloEvaluatorTest, DoesOr) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); @@ -222,7 +227,7 @@ TEST_P(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesXor) { +TEST_F(HloEvaluatorTest, DoesXor) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); @@ -231,7 +236,7 @@ TEST_P(HloEvaluatorTest, DoesXor) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. -TEST_P(HloEvaluatorTest, DoesMultiply) { +TEST_F(HloEvaluatorTest, DoesMultiply) { auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -242,14 +247,14 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_P(HloEvaluatorTest, DoesDivideInt64) { +TEST_F(HloEvaluatorTest, DoesDivideInt64) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_P(HloEvaluatorTest, DoesDivideDouble) { +TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = @@ -260,41 +265,41 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_P(HloEvaluatorTest, DoesAbsR2) { +TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR0) { auto operand = LiteralUtil::CreateR0(-1.0f); auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) { auto operand = LiteralUtil::CreateR1({}); auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesNegateR2) { +TEST_F(HloEvaluatorTest, DoesNegateR2) { auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorBf16Test, DoesCosR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorBf16Test, DoesSinR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesNotR2) { +TEST_F(HloEvaluatorTest, DoesNotR2) { auto operand = LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -303,9 +308,22 @@ TEST_P(HloEvaluatorTest, DoesNotR2) { {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } + +TEST_F(HloEvaluatorTest, DoesRealC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_real = LiteralUtil::CreateR1({1, -100}); + TestUnaryOp(HloOpcode::kReal, std::move(expected_real), std::move(x)); +} + +TEST_F(HloEvaluatorTest, DoesImagC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_imag = LiteralUtil::CreateR1({0, 4}); + TestUnaryOp(HloOpcode::kImag, std::move(expected_imag), std::move(x)); +} + // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); @@ -335,7 +353,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesReshape) { +TEST_F(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -361,7 +379,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { } // Verifies Broadcast operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesBroadcast) { +TEST_F(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = LiteralUtil::CreateR3( @@ -377,7 +395,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR0(111); auto output_literal = LiteralUtil::CreateR2( @@ -396,7 +414,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -418,7 +436,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -439,7 +457,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -458,7 +476,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2WithLayout( @@ -491,7 +509,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -516,7 +534,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -551,7 +569,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -593,7 +611,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } -TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -632,7 +650,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -678,7 +696,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -716,7 +734,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -766,7 +784,51 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { + HloComputation::Builder b(TestName()); + + auto lhs_array = absl::make_unique>(2, 2, 3, 1); + lhs_array->FillIota(1.0f); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_array = absl::make_unique>(2, 2, 3, 1); + rhs_array->FillIota(2.0f); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1}); + DotDimensionNumbers dot_dnums; + + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + float expected_1 = 0; + for (float i = 1.0f; i < 7.0f; ++i) { + expected_1 += i * i + i; + } + float expected_2 = 0; + for (float i = 7.0f; i < 13.0f; ++i) { + expected_2 += i * i + i; + } + auto expected_array = Array3D({{{expected_1}}, {{expected_2}}}); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -815,7 +877,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -878,7 +940,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -959,7 +1021,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -1037,7 +1099,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1101,7 +1163,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1166,7 +1228,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, +TEST_P(HloEvaluatorBf16Test, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1239,7 +1301,7 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { +TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { HloComputation::Builder b(TestName()); std::vector input_dims = {1, 2, 2, 4}; std::vector filter_dims = {2, 2, 2, 8}; @@ -1375,7 +1437,7 @@ void BM_ReducePrecisely(int num_iters) { BENCHMARK(BM_ReducePrecisely); -TEST_P(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1417,7 +1479,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1468,7 +1530,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { HloComputation::Builder b(TestName()); // arg: @@ -1520,7 +1582,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1577,7 +1639,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. @@ -1640,7 +1702,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } -TEST_P(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorBf16Test, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1674,7 +1736,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1690,12 +1752,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1710,7 +1774,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. -TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1726,12 +1790,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); + auto two = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1744,7 +1810,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1760,15 +1826,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto update = b.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - shape, operand, update, start_indices)); + shape, operand, update, {zero, one})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1781,7 +1849,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1817,7 +1885,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1856,7 +1924,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorBf16Test, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1909,7 +1977,7 @@ TEST_P(HloEvaluatorTest, Reverse) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1933,7 +2001,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1956,7 +2024,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -1980,7 +2048,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { const char* hlo_text = R"( HloModule TensorFlowGatherV2 @@ -2004,7 +2072,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowGatherMultipleBatchDims @@ -2029,7 +2097,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2055,7 +2123,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2082,7 +2150,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { const char* hlo_text = R"( HloModule DynamicSlice @@ -2105,7 +2173,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { const char* hlo_text = R"( HloModule BatchDynamicSlice @@ -2129,7 +2197,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -2151,7 +2219,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { const string hlo_text = R"( HloModule GatherXd @@ -2176,7 +2244,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV1 @@ -2207,7 +2275,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -2239,7 +2307,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2271,7 +2339,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2303,7 +2371,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { +TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2337,7 +2405,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2369,7 +2437,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowScatterMultipleBatchDims @@ -2402,7 +2470,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { const char* hlo_text = R"( HloModule TensorFlowScatterNd @@ -2438,7 +2506,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowScatterNdNonDefaultIndexVectorDim @@ -2475,7 +2543,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { const char* hlo_text = R"( HloModule DynamicUpdateSlice @@ -2507,7 +2575,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { const char* hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2539,7 +2607,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowScatter_ZeroDimBounds @@ -2568,7 +2636,7 @@ ENTRY main { operand, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { const string hlo_text = R"( HloModule Scatter_NoUpdateWindowDims @@ -2601,7 +2669,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter_NegativeIndices @@ -2636,7 +2704,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) { const string hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2672,7 +2740,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { const char* hlo_text = R"( HloModule TensorFlowScatterNd_OobUpdateWindow @@ -2711,7 +2779,7 @@ ENTRY main { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. -TEST_P(HloEvaluatorTest, DoesCompareBF16) { +TEST_F(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, @@ -2725,7 +2793,7 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } -TEST_P(HloEvaluatorTest, Bf16Reduction) { +TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { const string hlo_text = R"( HloModule Bf16Reduction @@ -2749,7 +2817,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } -TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { // Regression test for b/114735354. const string hlo_text = R"( HloModule SliceWithDifferentLayout @@ -2768,7 +2836,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } -TEST_P(HloEvaluatorTest, Bitcast) { +TEST_P(HloEvaluatorBf16Test, Bitcast) { // Regression test for b/114735354. constexpr absl::string_view hlo_text_base = R"( HloModule Bitcast @@ -2795,8 +2863,261 @@ ENTRY main { } } -INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, - ::testing::ValuesIn(use_bf16_params)); +// Check that s32 under/overflow doesn't trigger a ubsan failure. +TEST_F(HloEvaluatorTest, Int32Overflow) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + c1 = s32[] constant(1073741824) // 2^30 + sum = s32[] add(c1, c1) // 2^31, i.e. INT_MIN + + c2 = s32[] constant(-2147483648) // -2^31 + sub = s32[] subtract(c2, c1) // -2^31 - 2^30, underflows + + mul = s32[] multiply(c1, c1) + ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + std::vector actual = Evaluate({}).DecomposeTuple(); + ASSERT_EQ(actual.size(), 3); + + uint32 pow30 = uint32{1} << 30; + uint32 pow31 = uint32{1} << 31; + EXPECT_EQ(actual[0].GetFirstElement(), static_cast(pow31)); + EXPECT_EQ(actual[1].GetFirstElement(), + static_cast(-(pow31 + pow30))); + EXPECT_EQ(actual[2].GetFirstElement(), + static_cast(pow31 * pow31)); +} + +TEST_F(HloEvaluatorTest, GetDimensionSize) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + size = u32[] parameter(0) + + data = s32[4] parameter(1) + + sum = s32[4] add(data, data) + + ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + + // Set up dynamic parameter binding. + TF_CHECK_OK(m_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(m_.get())); + + evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference); + Literal size_arg = LiteralUtil::CreateR0(3); + Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); + + Literal actual = Evaluate({&size_arg, &data_arg}); + + EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); +} + +// Check that we get a useful error if we pass inputs of the wrong shape. +TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input_wrong_shape = LiteralUtil::CreateR1({0, 1}); + + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_, {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); +} + +// Check that we get a useful error if we pass too many or too few inputs. +TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input = LiteralUtil::CreateR1({0}); + + EXPECT_EQ( + HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(), + "Expected 1 argument, but got 2."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input, &input}) + .status() + .error_message(), + "Expected 1 argument, but got 2."); +} + +TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionInputLayout + + fused_computation { + param_0 = f32[20,20]{0,1} parameter(0) + ROOT bitcast = f32[20,20]{1,0} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{0,1} parameter(0) + ROOT fusion = f32[20,20]{1,0} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule MOFusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + bitcast = f32[20,20]{0,1} bitcast(param_0) + ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual_tuple = Evaluate({&args[0]}); + std::vector actual_literals = actual_tuple.DecomposeTuple(); + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual_literals[0].data())); +} + +// Tests that custom_calls fail to evaluate when no handler is specified. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_NoHandler + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::UNIMPLEMENTED); +} + +// Tests when a custom_call handler returns an error. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_HandlerError + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + return InternalError("Test error"); + }); + EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::INTERNAL); +} + +// Tests the custom_call handler on calls with many inputs. +// We sum the operands so that we can verify the operand and output literals +// are properly mapped for access. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_ManyInputs + ENTRY kernel_entry { + parameter.0 = u32[1]{0} parameter(0) + parameter.1 = u32[1]{0} parameter(1) + ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode()); + EXPECT_EQ("_my_custom_call", custom_call->custom_call_target()); + EXPECT_EQ(2, custom_call->operand_count()); + EXPECT_EQ(2, operands.size()); + auto output = Literal::CreateFromShape(custom_call->shape()); + auto operand0_data = operands[0]->data(); + auto operand1_data = operands[1]->data(); + auto output_data = output.data(); + output_data[0] = operand0_data[0] + operand1_data[0]; + return output; + }); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]})); + auto arg0_data = args[0].data(); + auto arg1_data = args[1].data(); + std::vector expected_data = {arg0_data[0] + arg1_data[0]}; + EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 3ace2f544329253d217e1891ce387a8a55fe2339..648c7d0e676cd85ea255557bd969d92659aeeca7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -39,9 +40,8 @@ namespace xla { // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is // a "private" header that's not exposed outside of hlo_evaluator.cc. template -using is_complex_t = std::is_same; -template -using is_complex64_t = std::is_same; +using is_complex_t = + absl::disjunction, std::is_same>; // It's UB to use std::sort with std::less, because of NaNs. Define // "safe" less functions which are actually strict weak orders. -NaN and NaN @@ -83,6 +83,26 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return SafeLess(static_cast(a), static_cast(b)); } +// ToArithmeticSafeType(T t): +// - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed +// integer, and +// - otherwise returns `t` unchanged. +// +// It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic +// in this type to force 2's complement behavior. +template ::value && + std::is_signed::value>::type* = nullptr> +typename std::make_unsigned::type ToArithmeticSafeType(T t) { + return static_cast::type>(t); +} +template ::value || + !std::is_signed::value>::type* = nullptr> +T ToArithmeticSafeType(T t) { + return std::move(t); +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -192,7 +212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(abs->operand(0)); @@ -211,6 +231,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // specifying the ElementwiseT explicitly as C64 is needed below. if (abs->operand(0)->shape().element_type() == C64) { return HandleAbs(abs); + } else if (abs->operand(0)->shape().element_type() == C128) { + return HandleAbs(abs); } return HandleAbs(abs); } @@ -498,47 +520,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { + Status HandleMultiply(HloInstruction* multiply) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp( + multiply, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) * + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp( + subtract, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) - + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } @@ -546,7 +546,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; + return ElementwiseT(ToArithmeticSafeType(lhs_elem) + + ToArithmeticSafeType(rhs_elem)); })); return Status::OK(); } @@ -674,11 +675,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[power], + ElementWiseBinaryOp( + power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0) + ? static_cast(1) + : std::pow(lhs_el, rhs_el); + })); return Status::OK(); } @@ -918,7 +922,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); + if (std::isnan(low) || std::isnan(high)) { + return static_cast(NAN); + } + return static_cast( + std::fmin(high, std::fmax(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -940,7 +948,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(ShapeUtil::IsArray(select->shape())); + CHECK(select->shape().IsArray()); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -993,8 +1001,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(lhs_shape.IsArray()); + CHECK(rhs_shape.IsArray()); CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); @@ -1005,8 +1013,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + const auto lhs_rank = lhs_shape.rank(); + const auto rhs_rank = rhs_shape.rank(); CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); @@ -1037,15 +1045,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); - int64 feature_group_count = conv->feature_group_count(); - int64 batch_group_count = conv->batch_group_count(); + const int64 feature_group_count = conv->feature_group_count(); + const int64 batch_group_count = conv->batch_group_count(); - // The batch count > 1 case is unimplemented in the HLO evaluator so far. - TF_RET_CHECK(batch_group_count == 1); auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data, - feature_group_count](const absl::Span out_index) { + rhs_literal_data, feature_group_count, + batch_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1058,6 +1064,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 input_z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + const int64 input_batch_size = + ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + const int64 batch_group_size = input_batch_size / batch_group_count; + // The size of an input feature group. const int64 input_feature_group_size = input_z_size / feature_group_count; @@ -1073,11 +1085,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 feature_group_index = out_index[output_z_dim] / output_feature_group_size; + const int64 batch_group_index = out_index[output_z_dim]; + ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), 0); // Convolve input feature with kernel. + // The mechanism indexes into the correct LHS (input) and RHS (kernel) + // locations and accumulates multiplications for a given output index. do { // Find corresponding spatial dimension index for input (lhs). int64 lhs_linear_spatial_index = 0; @@ -1130,11 +1146,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = lhs_linear_spatial_index; + lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; + + // We are scraping only the diagonal elements in the resultant + // convolution output when batch_group_count is greater than 1, + // where 1 is the default. No scraping is done in that case. + // This approach works out automatically for 'groups' in batches + // with group_size > 1, because we already descend down the batch + // dimension for the 'output_batch_dim' above. + lhs_linear_index += + ((batch_group_index * batch_group_size) % input_batch_size) * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; int64 rhs_linear_index = rhs_linear_spatial_index; + rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; @@ -1158,7 +1187,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { - if (parent_->use_fast_path_) { + if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && + parent_->use_fast_path_) { return HandleDot(dot); } return HandleDotSlowPath(dot); @@ -1169,21 +1199,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleDot(HloInstruction* dot) { const HloInstruction* lhs = dot->operand(0); const HloInstruction* rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); const auto& dnums = dot->dot_dimension_numbers(); - const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); - const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = lhs->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); // Contracted dimension sizes must be the same. @@ -1232,33 +1260,18 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleDotSlowPath(HloInstruction* dot) { auto lhs = dot->operand(0); auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + const auto lhs_rank = lhs->shape().rank(); + const auto rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. - CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), - rhs->shape().dimensions(rhs_contracting_dimension)) - << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracting_dimension) - << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -1272,7 +1285,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // in lhs_index or rhs_index where the i'th result index should go. absl::InlinedVector, kInlineRank> result_index_locations; - result_index_locations.reserve(lhs_rank + rhs_rank - 2); + result_index_locations.reserve( + (lhs_rank - dnums.lhs_contracting_dimensions_size()) + + (rhs_rank - dnums.rhs_contracting_dimensions_size())); // The first components in the output shape are the LHS and RHS batch // dimensions: @@ -1284,18 +1299,32 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension && + if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && + if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } + absl::InlinedVector accumulate_index_sizes; + accumulate_index_sizes.reserve(dnums.lhs_contracting_dimensions_size()); + absl::InlinedVector, kInlineRank> + accumulate_index_locations; + accumulate_index_locations.reserve(dnums.lhs_contracting_dimensions_size()); + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + const int64 lhs_dnum = dnums.lhs_contracting_dimensions(i); + const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i); + accumulate_index_locations.push_back( + {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]}); + const int64 dim_size = lhs->shape().dimensions(lhs_dnum); + accumulate_index_sizes.push_back(dim_size); + } + const int64 total_contraction_size = Product(accumulate_index_sizes); Literal result(dot->shape()); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span result_index) { @@ -1309,13 +1338,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; + absl::InlinedVector accumulate_index( + accumulate_index_sizes.size(), 0); + for (int64 k = 0; k < total_contraction_size; k++) { + for (int64 i = 0; i < accumulate_index_sizes.size(); ++i) { + *(accumulate_index_locations[i].first) = accumulate_index[i]; + *(accumulate_index_locations[i].second) = accumulate_index[i]; + } result_val += static_cast(lhs_literal.Get(lhs_index)) * static_cast(rhs_literal.Get(rhs_index)); + + // If there are no contracting dimension accumulate_index_sizes is + // empty, do not try to count down from -1 to 0 since it is and + // infinite loop. + if (!accumulate_index_sizes.empty()) { + for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) { + int64 value = ++accumulate_index[i]; + if (value != accumulate_index_sizes[i]) { + break; + } + accumulate_index[i] = 0; + } + } } return static_cast(result_val); @@ -1326,10 +1372,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); + CHECK(pad->operand(0)->shape().IsArray()); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + CHECK_EQ(pad->operand(0)->shape().rank(), pad->padding_config().dimensions_size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1352,9 +1398,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result.shape()), 0); + std::vector input_index(evaluated_operand.shape().rank(), 0); + std::vector target_index(result.shape().rank(), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1397,10 +1442,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operand = dynamic_slice->operand(0); auto start_indices = dynamic_slice->operand(1); auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1409,33 +1456,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { primitive_util::IsIntegralType(start_indices->shape().element_type())); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; default: LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " @@ -1455,7 +1508,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1466,33 +1521,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; default: LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " @@ -1529,7 +1590,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Literal computed_result = - embedded_evaluator.Evaluate(*computation, arg_literals) + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. @@ -1587,6 +1648,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case C128: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } default: LOG(FATAL) << "HandleMap: unhandled primitive type for " "input operand: " @@ -1609,7 +1674,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); int64 sort_dim = sort->dimensions(0); int64 sort_dim_elements = keys->shape().dimensions(sort_dim); - int64 rank = ShapeUtil::Rank(keys->shape()); + int64 rank = keys->shape().rank(); if (rank == 0) { // Nothing to sort. parent_->evaluated_[sort] = keys_literal.Clone(); @@ -1626,8 +1691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Extract a slice from the literal that corresponds to exactly the // row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; TF_ASSIGN_OR_RETURN(auto row_to_sort, keys_literal.Slice(indices, limit_indices) @@ -1670,7 +1734,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* hlo) override { HloReduceInstruction* reduce = Cast(hlo); int64 num_args = reduce->inputs().size(); - bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + bool has_tuple_output = reduce->shape().IsTuple(); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); @@ -1701,7 +1765,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // All args and results have the same dimensions, so pick an arbitrary one. const Shape& arg_shape = arg_literals[0]->shape(); - const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + const Shape& result_shape = reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(0) : reduce->shape(); const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); @@ -1790,7 +1854,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](Literal& literal) { return &literal; }); TF_ASSIGN_OR_RETURN(Literal computed_result, - embedded_evaluator.Evaluate( + embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1868,7 +1932,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - int64 rank = ShapeUtil::Rank(operand_literal.shape()); + int64 rank = operand_literal.shape().rank(); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank, 0); @@ -1906,8 +1970,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val_literal.Set({}, *selected_val); Literal computed_result = embedded_evaluator - .Evaluate( - *select, {&selected_val_literal, &curr_val_literal}) + .Evaluate(*select, + {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); bool selected = !computed_result.Get({}); if (selected) { @@ -1928,9 +1992,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scattered_literal.Set({}, scattered); Literal computed_result = embedded_evaluator - .Evaluate( - *scatter, - {&source_literal_scatter, &scattered_literal}) + .Evaluate(*scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again @@ -1980,7 +2043,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape().element_type(), window_dimension_sizes); DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + DimensionVector operand_index(operand_literal.shape().rank()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); Literal result(reduce_window->shape()); @@ -2004,8 +2067,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(result_val); Literal computed_result = embedded_evaluator - .Evaluate( - *function, {&result_val_literal, &curr_val_literal}) + .Evaluate(*function, + {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -2367,9 +2430,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(updates.Get(update_index)); Literal updated_result = embedded_evaluator - .Evaluate( - *scatter->to_apply(), - {&result_value_literal, &update_value_literal}) + .Evaluate(*scatter->to_apply(), + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. @@ -2411,7 +2473,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const int64 rank = ShapeUtil::Rank(operand->shape()); + const int64 rank = operand->shape().rank(); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); @@ -2608,7 +2670,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); + return InvalidArgument("Double is not supported for reduce precision"); } template < @@ -2623,12 +2685,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision(reduce_precision); } - template ::value || - std::is_same::value || - std::is_integral::value || - std::is_floating_point::value>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); @@ -2648,23 +2711,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result = LiteralUtil::CreateR1(data); - if (ShapeUtil::Rank(iota->shape()) > 1) { + if (iota->shape().rank() > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { - TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + TF_RET_CHECK(iota->shape().rank() == 1); parent_->evaluated_[iota] = std::move(result); } return Status::OK(); } - template ::value || - std::is_same::value || - std::is_integral::value || - std::is_floating_point::value)>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return UnsupportedTypeError(iota); } @@ -2672,6 +2736,103 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleIota(iota); } + template ::value || + std::is_floating_point::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + return UnsupportedTypeError(random); + } + template ::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::uniform_real_distribution generator( + low.Get({}), high.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + case RNG_NORMAL: { + const Literal& mean = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& stddev = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::normal_distribution generator(mean.Get({}), + stddev.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + template ::value)>::type* = + nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // Note std::uniform_int_distribution assumes interval is closed, i.e., + // [low, high], but we want [low, high) instead. Hence high-1 is used as + // the upper range. + std::uniform_int_distribution generator( + low.Get({}), high.Get({}) - 1); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + case RNG_NORMAL: { + return Unimplemented( + "Normal distribution is not supported for integral types."); + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + Status HandleRng(HloInstruction* random) override { + return HandleRng(random); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. @@ -2683,7 +2844,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // // This lets you calculate LI given the multidimensional indices in any order. static DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); + DimensionVector v(shape.rank()); int64 scale = 1; for (auto dim : LayoutUtil::MinorToMajor(shape)) { v[dim] = scale; @@ -2700,7 +2861,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Shape& window_shape, const Window& window, const Shape& base_shape, const absl::Span& window_count_index, const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); + const int64 rank = base_shape.rank(); DimensionVector window_index(rank); std::fill(window_index.begin(), window_index.end(), 0); do { @@ -2731,12 +2892,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicSlice(const Literal& operand_literal, - const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + StatusOr DynamicSlice( + const Literal& operand_literal, + absl::Span start_indices, + const Shape& result_shape) { + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the start indices so the slice is in-bounds w.r.t the operand. for (int64 i = 0; i < start.size(); ++i) { @@ -2762,14 +2938,28 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicUpdateSlice(const Literal& operand_literal, - const Literal& update_literal, - const Literal& start_indices_literal) { + StatusOr DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + absl::Span start_indices) { auto result = operand_literal.Clone(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result.shape()); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + const auto rank = result.shape().rank(); + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. for (int64 i = 0; i < rank; ++i) { @@ -2886,6 +3076,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f48140ee4f6ca9415bef80c83664213109dbf9f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54285a1577a3f3c97fba5ba6c2f969299ab599e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc708952d20a00429944c8388a84a0e610c2f38f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 5be9dba3aa49d63c580cd6f5800f608667826b6a..df06cf8c53ec8407f8b44c9126ed4fb5409f8ef3 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -45,7 +45,7 @@ TEST_F(HloExecutionProfileTest, Basic) { auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return pointer_size; } return ShapeUtil::ByteSizeOf(shape, pointer_size); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index c919dbd82d3668c477bf37074f1d56f8cb7d9506..862b2029718bbd802b69d789b66683a4edfa2367 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -25,7 +26,9 @@ namespace xla { namespace { -StatusOr ReplaceGetSize(HloInstruction* instr) { +StatusOr ReplaceGetSize( + HloInstruction* instr, + const DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -36,10 +39,18 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { instr->operand(0)->shape(), instr->dimension())); TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); - uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + } else { + uint32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + } return true; } @@ -48,10 +59,13 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { bool changed = false; HloProto proto; + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module)); *proto.mutable_hlo_module() = module->ToProto(); for (auto* computation : module->computations()) { for (auto instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced, + ReplaceGetSize(instruction, &inference)); changed = changed || replaced; } } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h index 30f44c23a835b3bcc935caaa917e040e07c4e703..9aa79fe66b665c48ec871c4188e44ba2056de3ad 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -21,7 +21,9 @@ limitations under the License. namespace xla { -// Pass to replace a kGetDimensionSize instruction with a constant instruction. +// Pass to replace a kGetDimensionSize instruction with a hlo instruction +// representing the dynamic size if the dimension is dynamic, otherwise a +// constant instruction representing the static size. class HloGetDimensionSizeRewriter : public HloModulePass { public: absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index dbf0d2c113bf670da3617967d913da819ccf2663..4c7f5e9e7dfb12a8cb699bdf397eab21983342a1 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -24,9 +24,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -380,7 +380,7 @@ class HloDotDumper { // Each HloInstruction dumped gets a monotically-increasing node ID. This // must start at 1, because that's where graphviz's accounting starts. int64 next_node_id_ = 1; - std::unordered_map node_ids_; + absl::flat_hash_map node_ids_; // The "root" tag doesn't have an associated HloInstruction pointer, so we // need to store it outside the map. @@ -397,7 +397,7 @@ class HloDotDumper { // Each HloComputation that's emitted gets a monotonically-increasing ID. int64 next_cluster_id_ = 1; - std::unordered_map cluster_ids_; + absl::flat_hash_map cluster_ids_; // Edges to print from Footer(). Edges come at the end because graphviz is // unhappy if an edge from a subcomputation to a node in the outer computation @@ -407,7 +407,7 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map + absl::flat_hash_map sharding_colors_; int64 next_shard_color_ = 0; }; @@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of( - subcomp->instructions().begin(), subcomp->instructions().end(), + return absl::c_any_of( + subcomp->instructions(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -733,17 +733,16 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { return true; } const int kMinUsersToOmit = 3; - return instr->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && - std::count_if(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return filter_.Show(user); - }) > kMinUsersToOmit && - std::all_of(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return !filter_.Show(user) || - user->opcode() == HloOpcode::kGetTupleElement; - }); + return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() && + !instr->IsFused() && + absl::c_count_if(instr->users(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + absl::c_all_of(instr->users(), [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); } string HloDotDumper::DumpInstruction(const HloInstruction* instr) { @@ -816,7 +815,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Print the literal value of constants with <= K elements. optional elem_count; - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; @@ -900,12 +899,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // the same color as a parameter. Unless the merged-in parameter is a // parameter to a fusion node that is bound to a constant -- these aren't // "real" parameters from the user's perspective. - if (std::any_of(instr->operands().begin(), instr->operands().end(), - [&](const HloInstruction* operand) { - return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand) && - TryGetFusionParameterConstant(operand) == nullptr; - })) { + if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; + })) { return parameter_color; } @@ -1286,7 +1284,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. - std::unordered_map nodes; + absl::flat_hash_map nodes; std::deque> worklist; worklist.push_back({root, 0}); while (!worklist.empty()) { @@ -1307,7 +1305,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // are not interesting to the graph at hand. if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { - if (!nodes.count(operand)) { + if (!nodes.contains(operand)) { worklist.push_back({operand, depth + 1}); } } @@ -1335,7 +1333,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, continue; } for (const HloInstruction* user : instr->users()) { - if (!nodes.count(user)) { + if (!nodes.contains(user)) { worklist.push_back({user, depth + 1}); } } @@ -1344,7 +1342,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. // Nodes in subcomputations are always shown. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant || instr->parent() != root->parent(); }; @@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, NodeFilterResult& filter_result = kv.second; const auto& operands = instr->operands(); - if (std::any_of(operands.begin(), operands.end(), is_displayed) && - !std::all_of(operands.begin(), operands.end(), is_displayed)) { + if (absl::c_any_of(operands, is_displayed) && + !absl::c_all_of(operands, is_displayed)) { // Mark nodes with some operands omitted appropriately. filter_result = kSomeOperandsOmitted; - } else if (!operands.empty() && - std::none_of(operands.begin(), operands.end(), is_displayed)) { + } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) { // Mark nodes with *all* operands omitted appropriately. filter_result = kOmitNodeOperands; } @@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their // users made it into the graph. if (filter_result == kSomeUsersOmitted && - std::all_of(instr->users().begin(), instr->users().end(), - is_displayed)) { + absl::c_all_of(instr->users(), is_displayed)) { filter_result = kNormalNode; } } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index 6e1597fd03db0a78aa560340b7b9b64fe500df0c..b01c00121b3363630b83a1e49d0027a66f3a9e1a 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -17,22 +17,34 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { + +bool HloInputOutputAliasConfig::OutputHasAlias( + const ShapeIndex& output_index) const { + return alias_.element(output_index).has_value(); +} + Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + const ShapeIndex& param_index, + AliasKind kind) { + TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) + << kind; TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << absl::StrCat("Tring to set up alias at ", output_index.ToString(), " which is an invalid index for shape ", ShapeUtil::HumanString(alias_.shape())); + TF_RET_CHECK(param_number >= 0) << param_number; + TF_RET_CHECK(!OutputHasAlias(output_index)) + << "Output index " << output_index << " already has an alias setup"; // Output can't be aliased with multiple parameters. TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( "Trying to set up output alias for param %lld at %s but failed: output " "index %s is already aliased with param %lld at %s", param_number, param_index.ToString(), output_index.ToString(), - alias_.element(output_index)->first, - alias_.element(output_index)->second.ToString()); + alias_.element(output_index)->parameter_number, + alias_.element(output_index)->parameter_index.ToString()); (*alias_.mutable_element(output_index)) = - std::make_pair(param_number, param_index); + Alias(kind, param_number, param_index); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -42,15 +54,24 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { HloInputOutputAliasProto result; alias_.ForEachElement( - [&](const ShapeIndex& index, - const absl::optional>& data) { + [&](const ShapeIndex& index, const absl::optional& data) { if (data) { HloInputOutputAliasProto::AliasEntryProto entry; + switch (data->kind) { + case AliasKind::kUserAlias: + entry.set_kind(HloInputOutputAliasProto::USER_ALIAS); + break; + case AliasKind::kSystemAlias: + entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS); + break; + default: + LOG(FATAL) << "Unknown alias kind " << data->kind; + } for (int64 i : index) { entry.add_output_shape_index(i); } - entry.set_parameter_number(data->first); - for (int64 i : data->second) { + entry.set_parameter_number(data->parameter_number); + for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } result.add_entries()->Swap(&entry); @@ -66,14 +87,18 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( proto.entries()) { ShapeIndex output_index(entry.output_shape_index().begin(), entry.output_shape_index().end()); - int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + // Handle backward compatibility with existing protos, which only knew of + // system aliases. + AliasKind kind = AliasKind::kSystemAlias; + if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) { + kind = AliasKind::kUserAlias; + } TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } - return result; } @@ -81,45 +106,44 @@ string HloInputOutputAliasConfig::ToString() const { std::vector pieces; pieces.push_back("HloInputOutputAliasConfig"); - ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { + const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM"; pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), param_number, param_index.ToString())); + " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:", + output_index.ToString(), kind, alias.parameter_number, + alias.parameter_index.ToString())); }); - return absl::StrJoin(pieces, "\n"); } -bool HloInputOutputAliasConfig::ParameterHasAlias( +HloInputOutputAliasConfig::AliasKind +HloInputOutputAliasConfig::ParameterAliasKind( int64 param_number, const ShapeIndex& param_index) const { - bool output = false; + AliasKind kind = AliasKind::kNoAlias; alias_.ForEachElement( - [&](const xla::ShapeIndex&, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { - output = true; + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { + kind = alias->kind; } }); - return output; + return kind; } absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; alias_.ForEachElement( - [&](const xla::ShapeIndex& output_index, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { + [&](const xla::ShapeIndex& output_index, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { output = output_index; } }); return output; } -absl::optional> +absl::optional HloInputOutputAliasConfig::GetAliasedParameter( const ShapeIndex& output_index) const { CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); @@ -128,10 +152,9 @@ HloInputOutputAliasConfig::GetAliasedParameter( void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { alias_.ForEachElement( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - fn(output_index, aliased->first, aliased->second); + fn(output_index, *aliased); } }); } @@ -139,10 +162,9 @@ void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { Status HloInputOutputAliasConfig::ForEachAliasWithStatus( AliasFnWithStatus fn) const { return alias_.ForEachElementWithStatus( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + TF_RETURN_IF_ERROR(fn(output_index, *aliased)); } return Status::OK(); }); @@ -158,20 +180,19 @@ Status HloInputOutputAliasConfig::Verify( param_has_seen.emplace_back(param->shape()); } return ForEachAliasWithStatus([&](const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) -> Status { + const Alias& alias) -> Status { const HloInstruction* root = entry->root_instruction(); - TF_RET_CHECK(0 <= param_number); - TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(0 <= alias.parameter_number); + TF_RET_CHECK(entry->num_parameters() > alias.parameter_number); const Shape& param_shape = - entry->parameter_instruction(param_number)->shape(); + entry->parameter_instruction(alias.parameter_number)->shape(); const Shape& output_shape = root->shape(); - TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index)); TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); const Shape& param_subshape = - ShapeUtil::GetSubshape(param_shape, param_index); + ShapeUtil::GetSubshape(param_shape, alias.parameter_index); const Shape& output_subshape = ShapeUtil::GetSubshape(output_shape, output_index); TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); @@ -182,19 +203,20 @@ Status HloInputOutputAliasConfig::Verify( "Expected aliased input %lld at index %s and output at index %s to " "have the same size. Input sub-shape is %s with size %lld, output " "sub-shape is %s with size %lld", - param_number, param_index.ToString(), output_index.ToString(), + alias.parameter_number, alias.parameter_index.ToString(), + output_index.ToString(), ShapeUtil::HumanStringWithLayout(param_subshape), size_func(param_subshape), ShapeUtil::HumanStringWithLayout(output_subshape), size_func(output_subshape)); } - // Check each param_number and param_index pair only show up once. No - // input can be aliased with output buffers. - TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); - - *(param_has_seen[param_number].mutable_element(param_index)) = true; - + // Check each alias.parameter_number and alias.parameter_index pair only + // show up once. No input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[alias.parameter_number].element( + alias.parameter_index) == false); + *(param_has_seen[alias.parameter_number].mutable_element( + alias.parameter_index)) = true; return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index 439676b1546c4af7f781fb80bccffd5248309b0f..b0b71dece81b561f492767db8c1ccbe3fde442d4 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -31,6 +32,28 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kUserAlias is one setup at + // compilation time by the user, and has to be respected. A kSystemAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kNoAlias, + kUserAlias, + kSystemAlias, + }; + + // Defines the alias information for a given output buffer. A given output + // buffer shape index can refer only to one parameter+index. + struct Alias { + Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index) + : kind(kind), + parameter_number(parameter_number), + parameter_index(std::move(parameter_index)) {} + + AliasKind kind; + int64 parameter_number; + ShapeIndex parameter_index; + }; + HloInputOutputAliasConfig() = default; explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} @@ -40,12 +63,22 @@ class HloInputOutputAliasConfig { // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, AliasKind kind); + + // Returns the kind of alias for the given parameter number and parameter + // index. If no alias exists, AliasKind::kNoAlias is returned. + AliasKind ParameterAliasKind(int64 param_number, + const ShapeIndex& param_index) const; // Returns true if the given parameter is aliased with one of the output // buffers. bool ParameterHasAlias(int64 param_number, - const ShapeIndex& param_index) const; + const ShapeIndex& param_index) const { + return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias; + } + + // Checks whether the provided output index has already been aliased. + bool OutputHasAlias(const ShapeIndex& output_index) const; // (De)Serializes an HloInputOutoutAliasConfig to/from an // HloInputOutoutAliasProto. @@ -63,19 +96,17 @@ class HloInputOutputAliasConfig { // Returns the number of parameter and index of the parameter buffer that the // given output buffer index is aliased with. A nullopt is returned if there // is no parameter is aliased with the specific output. - absl::optional> GetAliasedParameter( + absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; using AliasFn = - std::function; + std::function; // Iterates through each aliased output and input. void ForEachAlias(AliasFn fn) const; using AliasFnWithStatus = - std::function; + std::function; // Verifies that the given config is valid for the given module. // Specifically, the config's input and output should be in-bound and size of @@ -90,9 +121,10 @@ class HloInputOutputAliasConfig { private: // A ShapeTree which indicates the list of buffers that's expected to be // aliased. The key on this shape tree represents the output index. The value - // is a pair of parameter number and index into the buffer. If the value is - // nullopt, it means there is no parameter aliasing for this output. - ShapeTree>> alias_; + // is an Alias data structure which defines the input parameter coordinates. + // If the value is nullopt, it means there is no parameter aliasing for this + // output. + ShapeTree> alias_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index aeb9b0fdc8b6cca87731a2d4aae25120af6c3215..a46a107723de30176241aae01b268a8c10d991d3 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -45,11 +45,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_TRUE(aliased_output); EXPECT_EQ(aliased_output.value(), output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); EXPECT_TRUE(aliased_param); - EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + EXPECT_EQ(aliased_param->parameter_number, param_number); + EXPECT_EQ(aliased_param->parameter_index, param_index); } void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, @@ -60,11 +61,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_FALSE(aliased_output && aliased_output == output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); - EXPECT_FALSE(aliased_param && aliased_param->first == param_number && - aliased_param->second == param_index); + EXPECT_FALSE(aliased_param && + aliased_param->parameter_number == param_number && + aliased_param->parameter_index == param_index); } }; @@ -84,8 +86,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/1, /*param_index=*/{}, config); @@ -114,11 +118,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{0})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{1})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, config); @@ -149,11 +157,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -176,8 +188,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -200,11 +214,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + ASSERT_IS_NOT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3e8903c95376ae1238b68280bbbb00b0db5a23a2..3c92554ad4ec48686d64c74a00f732a3bfee87bc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -82,15 +83,14 @@ StatusOr> HloInstruction::CreateFromProto( return computation_map.at(proto.called_computation_ids(index)); }; - TF_RET_CHECK(std::all_of( - proto.operand_ids().begin(), proto.operand_ids().end(), - [&instruction_map](int64 id) { return instruction_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.operand_ids(), + [&](int64 id) { return instruction_map.contains(id); })) << proto.name() << " instruction contains invalid operand id(s)"; - TF_RET_CHECK(std::all_of( - proto.called_computation_ids().begin(), - proto.called_computation_ids().end(), - [&computation_map](int64 id) { return computation_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.called_computation_ids(), + [&](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; Shape shape(proto.shape()); @@ -311,7 +311,7 @@ StatusOr> HloInstruction::CreateFromProto( shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + TF_RET_CHECK(shape.IsTuple() && (ShapeUtil::TupleElementCount(shape) == 2)) << "Infeed should have a tuple shape with 2 operands, but has: " << shape; @@ -452,13 +452,43 @@ StatusOr> HloInstruction::CreateFromProto( CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "DynamicSlice instruction should have 2 operands but sees " - << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "DynamicSlice instruction should have at least 1 operands but " + "sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. + if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) { + auto expected_operands = 1 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicSlice instruction should have " << expected_operands + << " operands, but has " << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); + instruction = CreateDynamicSlice( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1), + slice_sizes); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + TF_RET_CHECK(proto.operand_ids_size() >= 2) + << "DynamicUpdateSlice instruction should have at least 2 operands " + "but sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. + if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) { + auto expected_operands = 2 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicUpdateSlice instruction should have " + << expected_operands << " operands, but has " + << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); instruction = - CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); + CreateDynamicUpdateSlice(shape, operands(0), operands(1), + absl::MakeSpan(operand_vector).subspan(2)); + break; } case HloOpcode::kGather: { @@ -628,7 +658,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. - CHECK(!ShapeUtil::IsOpaque(shape)); + CHECK(!shape.IsOpaque()); } auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { @@ -911,17 +941,17 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( - const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr -HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) { +HloInstruction::CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) { return absl::make_unique( shape, operand, update, start_indices); } @@ -1039,7 +1069,7 @@ HloInstruction::CreateBroadcastSequence( const std::function)>& adder) { CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + operand->shape().rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType( output_shape, operand->shape().element_type()); // Do explicit broadcast for scalar. @@ -1055,7 +1085,7 @@ HloInstruction::CreateBroadcastSequence( // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + for (int i = 0; i < operand->shape().rank(); i++) { if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); @@ -1132,7 +1162,7 @@ HloInstruction::CreateBroadcastSequence( void HloInstruction::set_single_sharding(const HloSharding& sharding) { CHECK(!sharding.IsTuple()) << sharding; - if (ShapeUtil::IsTuple(shape())) { + if (shape().IsTuple()) { set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); } else { set_sharding(sharding); @@ -1382,9 +1412,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateReshape(shape, new_operands[0]); break; case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + new_operands.subspan(2)); break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); @@ -1546,12 +1575,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - if (std::find(control_successors_.begin(), control_successors_.end(), - instruction) == control_successors_.end()) { + if (!absl::c_linear_search(control_successors_, instruction)) { control_successors_.push_back(instruction); - TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), - this) == instruction->control_predecessors_.end()); + TF_RET_CHECK( + !absl::c_linear_search(instruction->control_predecessors_, this)); instruction->control_predecessors_.push_back(this); } return Status::OK(); @@ -1800,7 +1827,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) { user_set_.erase(set_it); // This is linear in the number of the users, but a vector provides a stable // iteration order and much faster traversal. - auto vec_it = std::find(users_.begin(), users_.end(), user); + auto vec_it = absl::c_find(users_, user); CHECK(vec_it != users_.end()); users_.erase(vec_it); } @@ -1818,8 +1845,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, RemoveUser(user); - TF_RET_CHECK( - std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); + TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); @@ -1832,6 +1858,16 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, Status HloInstruction::ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand) { + auto old_operand = operand(operand_num); + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); + return ReplaceOperandWithDifferentShape(operand_num, new_operand); +} + +Status HloInstruction::ReplaceOperandWithDifferentShape( + int64 operand_num, HloInstruction* new_operand) { TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); @@ -1839,17 +1875,12 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, return Status::OK(); } - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), - new_operand->shape())) - << old_operand->shape() << " is not compatible with " - << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " << new_operand->name() << ", was " << old_operand->name(); - if (std::find(operands_.begin(), operands_.end(), old_operand) == - operands_.end()) { + if (!absl::c_linear_search(operands_, old_operand)) { old_operand->RemoveUser(this); } new_operand->AddUser(this); @@ -1857,6 +1888,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) + << shape() << " is not compatible with " << new_producer->shape(); + return ReplaceAllUsesWithDifferentShape(new_producer); +} + +Status HloInstruction::ReplaceAllUsesWithDifferentShape( + HloInstruction* new_producer) { bool new_producer_is_user = false; for (HloInstruction* user : users()) { if (user == new_producer) { @@ -1881,7 +1920,8 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { AddUser(new_producer); } if (parent_ && parent_->root_instruction() == this) { - parent_->set_root_instruction(new_producer); + parent_->set_root_instruction(new_producer, + /*accept_different_shape=*/true); } return Status::OK(); @@ -2824,7 +2864,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; case HloOpcode::kDynamicUpdateSlice: - // Dynamic-update-slice reuses only operand 2 (start_indices). + // Dynamic-update-slice reuses only start_indices. if (i == 0 || i == 1) { return UseKind::kUse; } @@ -2877,10 +2917,10 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = - std::any_of(padding.dimensions().begin(), padding.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.interior_padding() != 0; - }); + absl::c_any_of(padding.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 36e1ab49319a3e28143ef4d08888c68c86fbcf62..2c29b6c243bffccc346af12277dd4fc061250cbe 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -558,13 +558,14 @@ class HloInstruction { // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, absl::Span slice_sizes); + absl::Span start_indices, + absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. static std::unique_ptr CreateDynamicUpdateSlice( const Shape& shape, HloInstruction* operand, HloInstruction* update, - HloInstruction* start_indices); + absl::Span start_indices); // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. @@ -928,11 +929,16 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); - // Replaces the specified operand with new_operand. + // Replaces the specified operand with new_operand. The old and new operands + // must have compatible shapes ignoring floating-point precision. // // This function does NOT remove duplicated operands even if this instruction // is a fusion, so that the existing operand numbers do not change. - Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); + Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand); + + // Same as ReplaceOperandWith(), but new_operand can have a different shape. + Status ReplaceOperandWithDifferentShape(int64 operand_num, + HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use @@ -941,10 +947,16 @@ class HloInstruction { // If this instruction is the root of its computation, sets the computation's // root to new_producer. // + // The new producer must have a compatible shape ignoring floating-point + // precision. + // // If a user is a fusion instruction, this function will remove any duplicated // operands of it which could be created due to this replacement. Status ReplaceAllUsesWith(HloInstruction* new_producer); + // Same as ReplaceAllUsesWith, but new_producer can have a different shape. + Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); + // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8048e332cb57747286758b75773b29ba154aa888..35f031f29a7aca8db7ebe2fbcfdcebb7a778d703 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } Status HandleParameter(HloInstruction* parameter) override { - EXPECT_EQ(0, count_.count(parameter)); + EXPECT_FALSE(count_.contains(parameter)); count_[parameter] = GetCountsForNode(parameter); return Status::OK(); } Status HandleConstant(HloInstruction* constant) override { - EXPECT_EQ(0, count_.count(constant)); + EXPECT_FALSE(count_.contains(constant)); count_[constant] = GetCountsForNode(constant); return Status::OK(); } @@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override { auto lhs = add->operand(0); auto rhs = add->operand(1); - EXPECT_EQ(0, count_.count(add)); - EXPECT_GT(count_.count(lhs), 0); - EXPECT_GT(count_.count(rhs), 0); + EXPECT_FALSE(count_.contains(add)); + EXPECT_TRUE(count_.contains(lhs)); + EXPECT_TRUE(count_.contains(rhs)); count_[add] = GetCountsForNode(add); return Status::OK(); } Status HandleNegate(HloInstruction* negate) override { auto operand = negate->operand(0); - EXPECT_EQ(0, count_.count(negate)); - EXPECT_GT(count_.count(operand), 0); + EXPECT_FALSE(count_.contains(negate)); + EXPECT_TRUE(count_.contains(operand)); count_[negate] = GetCountsForNode(negate); return Status::OK(); } Status HandleMap(HloInstruction* map) override { - EXPECT_EQ(0, count_.count(map)); + EXPECT_FALSE(count_.contains(map)); for (HloInstruction* arg : map->operands()) { - EXPECT_GT(count_.count(arg), 0); + EXPECT_TRUE(count_.contains(arg)); } count_[map] = GetCountsForNode(map); return Status::OK(); @@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - EXPECT_EQ(0, count_.count(reduce)); - EXPECT_GT(count_.count(arg), 0); - EXPECT_GT(count_.count(init_value), 0); + EXPECT_FALSE(count_.contains(reduce)); + EXPECT_TRUE(count_.contains(arg)); + EXPECT_TRUE(count_.contains(init_value)); count_[reduce] = GetCountsForNode(reduce); return Status::OK(); } @@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } // Counters for HLOs. Maps HLO to a NumOpsAndUsers. - std::unordered_map count_; + absl::flat_hash_map count_; }; TEST_F(HloInstructionTest, BasicProperties) { @@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) { EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); - EXPECT_EQ(0, parameter->operand_count()); + EXPECT_FALSE(parameter->operand_count()); } TEST_F(HloInstructionTest, UserWithTwoOperands) { @@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) { module->AddEntryComputation(builder.Build()); int visit_num = 0; - std::unordered_map visit_order; + absl::flat_hash_map visit_order; EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { - EXPECT_EQ(0, visit_order.count(inst)); + EXPECT_FALSE(visit_order.contains(inst)); visit_order[inst] = visit_num; visit_num++; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 756e260b60dcda660e89c211862c8c5800439f2c..b01f01ef012b4c366035dc16b44508d71ad07d79 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -42,11 +42,9 @@ using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); + return absl::c_all_of(operand_indices, [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); } string PrecisionConfigToString(const PrecisionConfig& precision_config) { @@ -385,6 +383,15 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { return proto; } +bool HloAllReduceInstruction::IsNoop() const { + for (auto replica_group : replica_groups()) { + if (replica_group.replica_ids().size() != 1) { + return false; + } + } + return !all_reduce_id(); +} + std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = @@ -734,7 +741,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape, AppendComputation(map_computation); // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. - dimensions_.resize(ShapeUtil::Rank(shape)); + dimensions_.resize(shape.rank()); std::iota(dimensions_.begin(), dimensions_.end(), 0); } @@ -814,8 +821,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); + absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; }); for (int i = 0; i < slice_starts_.size(); ++i) { string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); bounds.push_back( @@ -866,7 +872,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { Shape* mutable_array_subshape = ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + CHECK(mutable_array_subshape->IsArray()); // Normally array_subshape will always have a layout, but this invariant is // temporarily broken in LayoutAssignment::AssignLayouts. @@ -900,7 +906,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( string operands; // For constants, show the actual value in place of an empty operand list. if (literal_.has_value() && - ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. @@ -1051,8 +1057,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( void HloFusionInstruction::MergeFusionInstruction( HloFusionInstruction* instruction_to_merge) { - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); + CHECK(absl::c_linear_search(operands(), instruction_to_merge)); // Clone the instruction from which to merge fused instructions. std::unique_ptr cloned = instruction_to_merge->Clone(); HloFusionInstruction* cloned_fusion = @@ -1219,8 +1224,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( // corresponding fused parameter instruction. Renumber parameters as // necessary to make parameter numbers consistent with their index in the // fused_parameter_ vector. - bool in_operand_list = std::find(operands().begin(), operands().end(), - instruction_to_fuse) != operands().end(); + bool in_operand_list = + absl::c_linear_search(operands(), instruction_to_fuse); CHECK(add_output || in_operand_list); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { // We assume all uses of a kTuple operation are GTE ops, not another @@ -1324,7 +1329,7 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( if (newly_created_tuple_instr) { HloInstruction* new_instr = parent()->AddInstruction( HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr)); } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { @@ -1706,6 +1711,10 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", batch_group_count_)); + } + string precision_config_string = PrecisionConfigToString(precision_config_); if (!precision_config_string.empty()) { extra.push_back(precision_config_string); @@ -2007,6 +2016,18 @@ HloDynamicSliceInstruction::HloDynamicSliceInstruction( AppendOperand(start_indices); } +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes) + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) @@ -2016,6 +2037,17 @@ HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { @@ -2041,9 +2073,14 @@ std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return absl::make_unique( - shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. + return absl::make_unique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + } else { + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_); + } } HloGatherInstruction::HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ca212c7f2c98f75ceefc14b7fbc2a1f530c06cf7..1b4a94753cda8aba8d50836b9d51b7c3fd5807f6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -253,6 +253,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1183,7 +1187,22 @@ class HloDynamicIndexInstruction : public HloInstruction { public: explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) : HloInstruction(opcode, shape) {} - virtual int64 index_operand_number() const = 0; + virtual int64 first_index_operand_number() const = 0; + + // Returns a subspan of operands which represent the start indices. + absl::Span index_operands() const { + return absl::MakeSpan(operands()).subspan(first_index_operand_number()); + } + + // Returns the shapes of the index operands. + std::vector index_shapes() const { + std::vector shapes; + auto indices = index_operands(); + for (const HloInstruction* index : indices) { + shapes.push_back(index->shape()); + } + return shapes; + } }; class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { @@ -1192,6 +1211,10 @@ class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes); + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1204,7 +1227,7 @@ class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; - int64 index_operand_number() const override { return 1; } + int64 first_index_operand_number() const override { return 1; } private: std::vector ExtraAttributesToStringImpl( @@ -1229,8 +1252,11 @@ class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices); + explicit HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices); - int64 index_operand_number() const override { return 2; } + int64 first_index_operand_number() const override { return 2; } }; class HloGatherInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index dc712e5e42c449737bf4415f3a5e3eb9d81d9be4..798760885dcd55e0a1cbdf403fa160347d67fc3a 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" @@ -37,8 +38,8 @@ constexpr int kError = -2; // [a-zA-Z0-9_.-] bool IsIdentifierChar(char c) { - return isalnum(static_cast(c)) || c == '-' || c == '.' || - c == '_'; + return absl::ascii_isalnum(static_cast(c)) || c == '-' || + c == '.' || c == '_'; } } // namespace @@ -105,7 +106,7 @@ TokKind HloLexer::LexToken() { switch (current_char) { default: // [a-zA-Z_] - if (isalpha(static_cast(current_char)) || + if (absl::ascii_isalpha(static_cast(current_char)) || current_char == '_') { return LexIdentifier(); } @@ -140,6 +141,12 @@ TokKind HloLexer::LexToken() { return LexNumberOrPattern(); case '=': return TokKind::kEqual; + case '<': + if (current_char == '<' && PeekCurrentChar() == '=') { + current_ptr_++; + return TokKind::kLeq; + } + return TokKind::kError; case ',': return TokKind::kComma; case '%': @@ -294,7 +301,7 @@ TokKind HloLexer::LexIdentifier() { // name ::= [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexPercent() { const char* name_start = current_ptr_; - if (isalpha(static_cast(PeekCurrentChar())) || + if (absl::ascii_isalpha(static_cast(PeekCurrentChar())) || PeekCurrentChar() == '_') { current_ptr_++; while (IsIdentifierChar(PeekCurrentChar())) { @@ -462,6 +469,8 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; + case TokKind::kLeq: + return "kLeq"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 41f5043904a2622814154693679a0e27cb92f642..94fac3cd8e9da7f273e7e521e21510f5188702e6 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -49,6 +49,7 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> + kLeq, // <= // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 5bf055f3c012fef687cdc275d62efdf2d4cd5e5c..e14bcfa7f67e736a4d04f5b236fb2df02cf150e0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -36,11 +37,11 @@ namespace xla { namespace { using Worklist = std::deque; -using Workset = std::unordered_set; +using Workset = absl::flat_hash_set; void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { - if (workset->count(instruction) == 0) { + if (!workset->contains(instruction)) { worklist->push_back(instruction); workset->insert(instruction); VLOG(3) << "ADD instruction: " << instruction->name(); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 7227bfb27c74758d2b79e404afc9eb97a1ca894d..76cc29cbb7848eb424d07abf11a95ffd59e9eed6 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -118,7 +118,7 @@ class HloTrivialScheduler : public HloModulePass { }; // A trivial pass which clears the schedule currently set on the -// HloModule. After this pass runs HloModudle::has_schedule will return false. +// HloModule. After this pass runs HloModule::has_schedule will return false. class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fe8371384c0fa3900a9022f101ff0b296439cf16..258f918f47a313b4b89fb260457b1b119dc16177 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation( } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { - auto it = - std::find_if(computations_.begin(), computations_.end(), - [&to_remove](const std::unique_ptr& comp) { - return comp.get() == to_remove; - }); + auto it = absl::c_find_if( + computations_, [&to_remove](const std::unique_ptr& comp) { + return comp.get() == to_remove; + }); TF_RET_CHECK(it->get() == to_remove); computations_.erase(it); return Status::OK(); @@ -304,11 +303,10 @@ StatusOr> HloModule::CreateFromProto( auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. - std::sort(computations.begin(), computations.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(computations, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); // Add sorted computations to the module. for (auto& computation : computations) { @@ -392,15 +390,12 @@ namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given // subcomputation. -bool IsUsedOutsideSubcomputation( - const HloInstruction& hlo, - const std::unordered_set& instructions_in_subcomputation) { - for (HloInstruction* user : hlo.users()) { - if (!instructions_in_subcomputation.count(user)) { - return true; - } - } - return false; +bool IsUsedOutsideSubcomputation(const HloInstruction& hlo, + const absl::flat_hash_set& + instructions_in_subcomputation) { + return absl::c_any_of(hlo.users(), [&](HloInstruction* user) { + return !instructions_in_subcomputation.contains(user); + }); } } // anonymous namespace @@ -411,9 +406,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // A map from original instructions to their counterparts in the new outlined // function. - std::unordered_map outlined_instructions; + absl::flat_hash_map outlined_instructions; // A set that contains all instructions to be outlined. - std::unordered_set instruction_set_to_outline( + absl::flat_hash_set instruction_set_to_outline( instructions_to_outline.begin(), instructions_to_outline.end()); std::vector arguments; std::vector outputs; @@ -502,7 +497,7 @@ std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). - std::set nonroot_computations; + absl::flat_hash_set nonroot_computations; for (auto& computation : computations_) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -515,19 +510,19 @@ std::vector HloModule::MakeComputationPostOrder() const { // Keep track of computations which have already been added to the post // order. This prevents duplication as an embedded computation may be called // from two different root computations. - std::set added_computations; + absl::flat_hash_set added_computations; std::vector post_order; for (auto& computation : computations_) { - if (nonroot_computations.count(computation.get()) == 0) { + if (!nonroot_computations.contains(computation.get())) { for (HloComputation* embedded_computation : computation->MakeEmbeddedComputationsList()) { - if (added_computations.count(embedded_computation) == 0) { + if (!added_computations.contains(embedded_computation)) { post_order.push_back(embedded_computation); added_computations.insert(embedded_computation); } } // Root computations should only be encountered once. - CHECK_EQ(0, added_computations.count(computation.get())); + CHECK(!added_computations.contains(computation.get())); post_order.push_back(computation.get()); added_computations.insert(computation.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 31d26cc51e8217234526bbfeb83510aadf2c27b5..6b72ba128664d27c51aa8dcfa61fe959a0160c73 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -49,7 +49,7 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_param = while_body_comp->parameter_instruction(0); auto* while_body_root = while_body_comp->root_instruction(); - if (!ShapeUtil::IsTuple(xla_while->shape()) || + if (!xla_while->shape().IsTuple() || while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, // with no I/O instructions. diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index e535b7d74943943069b4d795cf999a3b1e963360..f6e2866204955ac024c2b6f972de449cc3df4c15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } // Returns whether the while instruction with name 'while_name' in diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4aac4c8076cb69647d42c6243bc969d06d0709e..47734bc55cc00d605f4e318400be88639450343c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -79,36 +79,36 @@ Status HloModuleGroupMetadata::Build() { return Status::OK(); } - std::vector peers; - if (IsChannelInstruction(hlo)) { - peers.push_back(PeerComputation(hlo)); - } else if (hlo->IsCrossModuleAllReduce()) { - for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { - if (instr == hlo) { - continue; + if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) { + std::vector peers; + if (IsChannelInstruction(hlo)) { + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); } - peers.push_back(instr->parent()); } - } - - // Add the parent computation of this channel (or all-reduce) instruction - // and its peer computation(s) (both must be while computations) as - // companions. - for (HloComputation* peer_computation : peers) { - const TrackedInstruction* peer_tracked = - GetTrackedInstruction(peer_computation); - TF_RET_CHECK(peer_tracked != nullptr) - << "Peer instruction is not a possible companion"; - TF_RET_CHECK(*tracked == *peer_tracked) - << "Peer instruction does not match the computation kind"; - TF_RETURN_IF_ERROR( - AddCompanion(tracked->instruction(), peer_tracked->instruction())); - tracked_instructions_comms_[tracked->instruction()].push_back(hlo); - } - // Add the parents of companion instructions (they must be all of the same - // kind of instructions, opcode wise) as companions. - if (IsCompanionInstruction(hlo)) { + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. + for (HloComputation* peer_computation : peers) { + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); + } + } else if (IsCompanionInstruction(hlo)) { + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. for (HloInstruction* companion : Companions(hlo)) { const TrackedInstruction* companion_tracked = GetTrackedInstruction(companion->parent()); @@ -118,6 +118,7 @@ Status HloModuleGroupMetadata::Build() { companion_tracked->instruction())); } } + return Status::OK(); }; @@ -198,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction( } bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { - return companion_set_index_.count(hlo) > 0; + return companion_set_index_.contains(hlo); } bool HloModuleGroupMetadata::InstructionCommunicates( @@ -509,7 +510,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( HloComputation* computation = instruction->parent(); const HloModule* module = computation->parent(); if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { + tracked_instructions_.contains(computation)) { return Status::OK(); } return FailedPrecondition("channel is used in disallowed computation"); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 928df0f5a7444ad877961a5de970c752e1d024da..3ed95c10504141139d83eb8679a0b8144b15ad0d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -38,7 +38,7 @@ namespace xla { // Class for bookkeeping the information on the given modules, in particular on // the interaction between computations. // -// Companion instructions are one of the information collected as we build the +// Companion instructions are one piece of information collected as we build the // metadata. For example, for each While instruction, companion instructions // refer to a set of While instructions in other computations that communicate // with each other. @@ -51,6 +51,13 @@ namespace xla { // } While_4() { Recv(0) } // } // +// Each instruction can belong to at most one companion set: While_0 and While_5 +// are in the same set even though they don't communicate with each other, +// because they both communicate with While_2. +// +// A send and the matching recv must both have the same level of nesting of +// companion instructions. +// // Companion instructions are used to detect cycles in the graph and also for // global scheduling. class HloModuleGroupMetadata { @@ -171,7 +178,7 @@ class HloModuleGroupMetadata { // Precondition: IsCompanionWhile(instruction) is true. const std::vector& Companions( const HloInstruction* instruction) const { - CHECK_EQ(companion_set_index_.count(instruction), 1); + CHECK(companion_set_index_.contains(instruction)); return companion_set(companion_set_index_.at(instruction)); } @@ -215,11 +222,8 @@ class HloModuleGroupMetadata { // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). // * The shape of channel instructions match. // * The nest level of channel instructions match. - // * Channel instructions are used in allowed computations; i.e., in the + // * Channel instructions are used in allowed computations, i.e., in the // entry computation of the module or condition/body of While computations. - // - // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a - // cycle in the graph, but it would be good to verify here. Status VerifyChannelInstructions(); // Adds metadata that the given two instructions are companions. @@ -231,8 +235,8 @@ class HloModuleGroupMetadata { Status CheckCommunicatingInstruction(HloInstruction* instruction) const; // Performs a consistency check on the companion sets built for the input - // modules. Check that a companion set does not include instructions from the - // same module/device. + // modules. Checks that each instruction in a companion set is in a different + // module/device. Status VerifyCompanionSets() const; // Retrieves a pointer to the stored TrackedInstruction associated with a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6..91417bd2d9a6ca8a5192a37302e6a91e49a94d77 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -198,6 +198,8 @@ std::vector HloModuleGroupUtil::RootInstructions( for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { if (GlobalSuccessors(instruction).empty()) { + // An instruction that has no successors, e.g., an unused instruction, + // is in roots, even though it's not the ROOT of its computation. roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d..862666b48c9aa423ba4eeea3052c17fcc1064fd2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -49,7 +49,7 @@ class HloModuleGroupUtil { // Returns all unique successors of the instruction. This includes: // * successors in the same computation: users and control successors // * Send is a successor of Recv - // * RecvDone is a predecessor of Send + // * RecvDone is a successor of Send // * successors of companions (if the instruction is a companion while) // * successors' companions (for any successor that is a companion while) std::vector GlobalSuccessors(HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index ca6a154809be46d6a0305c29e2b89219de408019..0cec61c257bb84e467290fb52ec9063a32ed558d 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const HloInstruction* a, const HloInstruction* b) const { CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. - if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + if (!order_position_.contains(a) || !order_position_.contains(b)) { return false; } return order_position_.at(a) < order_position_.at(b); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 44643951c14fb3a210b27064ffac4b99734bca0a..638396308c2a9c1f20e47f78b594d54f07c0c4e5 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -257,7 +257,8 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); - bool ParseDimensionSizes(std::vector* dimension_sizes); + bool ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions); bool ParseShape(Shape* result); bool ParseLayout(Layout* layout); bool ParseOpcode(HloOpcode* result); @@ -1170,24 +1171,39 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.empty()) { + return Error(loc, "Expected at least one operand."); + } + if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) && + operands.size() != 1 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( - shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + shape, /*operand=*/operands[0], + /*start_indices=*/absl::MakeSpan(operands).subspan(1), *dynamic_slice_sizes)); break; } case HloOpcode::kDynamicUpdateSlice: { - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() < 2) { + return Error(loc, "Expected at least two operands."); + } + if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) && + operands.size() != 2 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, /*operand=*/operands[0], /*update=*/operands[1], - /*start_indices=*/operands[2])); + /*start_indices=*/absl::MakeSpan(operands).subspan(2))); break; } case HloOpcode::kTranspose: { @@ -1287,7 +1303,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail // if the shape is not a non-empty tuple, so add guard so an error message // can be emitted instead of a check fail - if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) { return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } @@ -1931,8 +1947,8 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, // ::= tuple // ::= non_tuple bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { - return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) - : ParseNonTupleLiteral(literal, shape); + return shape.IsTuple() ? ParseTupleLiteral(literal, shape) + : ParseNonTupleLiteral(literal, shape); } // tuple @@ -1980,7 +1996,7 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { - const tensorflow::int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = shape.rank(); // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); @@ -2145,7 +2161,7 @@ template bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; - tensorflow::int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = shape.rank(); *literal = Literal(shape); @@ -2730,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } auto is_unique = [](string str) -> bool { - std::sort(str.begin(), str.end()); + absl::c_sort(str); return std::unique(str.begin(), str.end()) == str.end(); }; @@ -2971,14 +2987,25 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } -// dimension_sizes ::= '[' int64_list ']' -bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes) { +// dimension_sizes ::= '[' dimension_list ']' +// dimension_list +// ::= /*empty*/ +// ::= <=? int64 (',' param)* +// param ::= name shape +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions) { auto parse_and_add_item = [&]() { tensorflow::int64 i; + bool is_dynamic = false; + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } if (!ParseInt64(&i)) { return false; } dimension_sizes->push_back(i); + dynamic_dimensions->push_back(is_dynamic); return true; }; return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, @@ -3034,12 +3061,18 @@ bool HloParser::ParseShape(Shape* result) { PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + // Each element contains a dimension size and a bool indicating whether this + // is a dynamic dimension. std::vector dimension_sizes; - if (!ParseDimensionSizes(&dimension_sizes)) { + std::vector dynamic_dimensions; + if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) { return false; } result->set_element_type(primitive_type); - *result->mutable_dimensions() = dimension_sizes; + for (int i = 0; i < dimension_sizes.size(); ++i) { + result->add_dimensions(dimension_sizes[i]); + result->set_dynamic_dimension(i, dynamic_dimensions[i]); + } LayoutUtil::SetToDefaultLayout(result); if (lexer_.GetKind() == TokKind::kw_sparse) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ef31cec32770690505b437d8678c45150766e559..6ba16cc82ac1da2a30610d9dfb56cacc100ae05f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -551,6 +551,17 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] { ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } +)" +}, +{ +"TransposeC128", +R"(HloModule TransposeC128_module + +ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] { + %input = c128[1,2,3]{2,1,0} parameter(0) + ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2} +} + )" }, // Dynamic slice @@ -566,12 +577,26 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} } +)" +}, +// Dynamic slice with scalar indices +{ +"DynamicSliceScalarIndices", +R"(HloModule DynamicSlice_module + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} +} + )" }, // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module +R"(HloModule DynamicSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -580,6 +605,23 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) } +)" +}, +// Dynamic update slice with scalar indices +{ +"DynamicUpdateSliceScalarIndex", +R"(HloModule DynamicUpdateSlice_module + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) +} + )" }, // batch norm training @@ -1329,20 +1371,20 @@ TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } TEST_P(HloParserTestShort, Run) { ExpectEqual(); } TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestLongProto, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestShortProto, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); class HloParserTest : public ::testing::Test { protected: @@ -2329,5 +2371,25 @@ TEST_F(HloParserTest, ParseInvalidShapeString) { } } +TEST_F(HloParserTest, ParseDynamicArray) { + string shape_string = "f32[123,<=456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseDynamicTuple) { + string shape_string = "(f32[42], u32[<=123,<=456])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {42}), + ShapeUtil::MakeShape(U32, {123, 456}, {true, true})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 791b1a97b0b82edf19ff1588fd8d5d996ac0fef4..35dc9c0029f9871334cb500c6b71f0c86ab136d7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -39,9 +40,36 @@ class HloPassFix : public Pass { int64 iteration_count = 0; int64 limit = std::max(static_cast(1000), module->instruction_count()); + VLOG(3) << "Running HloPassFix."; while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly high number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } + } + return changed; + } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = 1000; + for (const HloModule* module : module_group->modules()) { + limit = std::max(limit, module->instruction_count()); + } + VLOG(3) << "Running HloPassFix."; + while (changed_this_iteration) { + TF_ASSIGN_OR_RETURN(changed_this_iteration, + Pass::RunOnModuleGroup(module_group)); + changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == limit) { LOG(ERROR) diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 33ce7e23a82d840676bba5f1ca9c0ffc4433465d..ae8c08cf1d16ad6738962f3be7c1b5512110b1d1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -89,7 +89,7 @@ std::vector HloPassPipeline::GetEnabledPasses( std::vector enabled_passes; for (auto& pass : passes_) { - if (disabled_pass_names.count(string(pass->name())) == 0) { + if (!disabled_pass_names.contains(pass->name())) { enabled_passes.push_back(pass.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index 5eb707a957e49d86cdb2f72b72ce750bf29b8fd2..9cc202aa9f5fe5a20a9da05251ea811137ccaadb 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); - bool any_instruction_profiled = - std::any_of(instruction_infos.begin(), instruction_infos.end(), - [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index()] != 0; - }); + bool any_instruction_profiled = absl::c_any_of( + instruction_infos, [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index()] != 0; + }); if (!any_instruction_profiled) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index edaa4c59e2674e5f165c468059747d3dd2d54218..0fced7f15bdaf1dbe349e3b0fc6ada68393c6512 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. - if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { + if (!absl::c_linear_search(inputs, instruction)) { bit_vector->SetToZero(); } bit_vector->Set(GetIndex(instruction)); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index ac74e2432f2176e13eaf7d4a1934a50ee89d1042..a175e4643de2ac6ce07ac00da914d7ab7acca541 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -57,6 +57,15 @@ using ::tensorflow::strings::HumanReadableNumBytes; // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kCopy) { + if (LayoutUtil::Equal(instruction->shape().layout(), + instruction->operand(0)->shape().layout())) { + // Don't rematerialize copies added by copy insertion (layout doesn't + // change). + return false; + } + } + // Don't rematerialize instructions with side effects or instructions which // cannot be cloned safely. switch (instruction->opcode()) { @@ -179,7 +188,8 @@ class InstructionList { Item* CreateItem(HloInstruction* inst) { Item* item = new Item; item->instruction = inst; - CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; + CHECK(item_map_.insert({inst, item}).second) + << "inserting inst twice " << inst->name(); return item; } @@ -235,8 +245,7 @@ class InstructionList { } // Now scan forwards until we find one of the before_instructions. - while (std::find(before_instructions.begin(), before_instructions.end(), - min_position_item) == before_instructions.end()) { + while (!absl::c_linear_search(before_instructions, min_position_item)) { min_position_item = min_position_item->next; } return InsertBefore(to_insert, min_position_item); @@ -302,7 +311,7 @@ ItemList GetUsers(const InstructionList& instruction_list, // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); - if (std::find(users.begin(), users.end(), user_item) == users.end()) { + if (!absl::c_linear_search(users, user_item)) { users.push_back(user_item); } } @@ -418,11 +427,12 @@ class MemoryUsageTracker { // the given uses. Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, ItemList&& rematerialized_uses) { - CHECK(original_buffer.defining_instruction->placed); - CHECK(!original_buffer.has_indirect_uses); - CHECK(!original_buffer.live_out); + CHECK(original_buffer.defining_instruction->placed) + << original_buffer.defining_instruction->instruction->name(); + CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString(); + CHECK(!original_buffer.live_out) << original_buffer.ToString(); for (Item* use : rematerialized_uses) { - CHECK(!use->placed); + CHECK(!use->placed) << use->instruction->name(); } return NewBuffer(remat_item, original_buffer.size, std::move(rematerialized_uses), /*live_out=*/false, @@ -456,8 +466,7 @@ class MemoryUsageTracker { return false; } const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; - return std::find(in_progress_uses.begin(), in_progress_uses.end(), - buffer_id) != in_progress_uses.end(); + return absl::c_linear_search(in_progress_uses, buffer_id); } // Returns whether the given instruction is live at the current program @@ -535,8 +544,7 @@ MemoryUsageTracker::MemoryUsageTracker( bool unused; for (Item* user_item : GetUsers(instruction_list_, logical_buffer, points_to_analysis, &unused)) { - if (std::find(buffer->users.begin(), buffer->users.end(), - user_item) == buffer->users.end()) { + if (!absl::c_linear_search(buffer->users, user_item)) { buffer->users.push_back(user_item); buffer->unfinished_user_count++; user_item->buffers_used.push_back(buffer->id); @@ -677,8 +685,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, << ", remat_instruction = " << remat_item->instruction->name(); TF_RET_CHECK(in_progress_item_ != nullptr); - TF_RET_CHECK(original_item->placed); - TF_RET_CHECK(!remat_item->placed); + TF_RET_CHECK(original_item->placed) << original_item->instruction->name(); + TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name(); // Construct the list of buffers used and defined by the rematerialization. remat_item->buffers_used = original_item->buffers_used; @@ -707,7 +715,7 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, ItemList unplaced_users; for (Item* user : old_buffer.users) { if (user->placed) { - CHECK(IsFinished(user)); + CHECK(IsFinished(user)) << user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -784,8 +792,7 @@ bool MemoryUsageTracker::Check() const { for (const Buffer& buffer : buffers_) { if (buffer.defining_instruction->instruction == instruction) { - CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), - buffer.id) != defined_buffers.end()) + CHECK(absl::c_linear_search(defined_buffers, buffer.id)) << "Instruction " << instruction->name() << " defined buffers is missing: " << buffer.ToString(); } @@ -808,8 +815,7 @@ bool MemoryUsageTracker::Check() const { int64 unfinished_uses = 0; for (Item* user : buffer.users) { const BufferIdList& used_buffers = user->buffers_used; - CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != - used_buffers.end()) + CHECK(absl::c_linear_search(used_buffers, buffer.id)) << "Instruction " << user->instruction->name() << " used buffers is missing " << buffer.ToString(); if (!IsFinished(user)) { @@ -836,10 +842,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. - if (!std::any_of(instruction->users().begin(), instruction->users().end(), - [&memory_tracker](const HloInstruction* inst) { - return memory_tracker.IsPlaced(inst); - })) { + if (!absl::c_any_of(instruction->users(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { return 0; } @@ -1094,7 +1100,7 @@ StatusOr HloRematerialization::RematerializeComputation( Item* successor_item = instruction_list.GetItem(successor); // Assert to make sure we never remat an operation with control // successor already placed. - CHECK(!successor_item->placed); + CHECK(!successor_item->placed) << successor_item->instruction->name(); place_before.push_back(successor_item); } instruction_list.InsertBeforeInstructions(remat_item, place_before); @@ -1164,7 +1170,7 @@ StatusOr HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto* instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction)); + CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name(); } VLOG(1) << "In computation " << computation->name() << " rematerialized " diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 22c3c40a93a1ddcd36659483fcc79fede32dd2c3..102a360ad8116d8781baf9cb7627a920f4a687c4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -499,6 +499,52 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); } +TEST_F(HloRematerializationTest, CopyNotRematerialized) { + // Test that copies are not rematerialized. + auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param)); + + auto negate_a_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_a_1)); + + auto negate_b_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_b_1)); + + builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2})); + + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/1 * 1024, module.get())); + + auto count_copies = [](const HloComputation* computation) { + int64 copy_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + copy_count++; + } + } + return copy_count; + }; + EXPECT_TRUE(changed); + + EXPECT_EQ(count_copies(entry_computation), 1); +} + class IndirectUseTest : public HloRematerializationTest, public ::testing::WithParamInterface {}; @@ -588,8 +634,8 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { } } -INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, - ::testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 5a9b820a9d7f58695383b21c9e2126cf98970c83..d7d66ae1c4592723ca991d5ee971fa72cc1af90a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -383,9 +383,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions( - run_options, backend().StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); + return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 8f6eb974c5179b420c8f961393ca923e0a3b3530..e75373501cffac6a736be89e9f6139b6ff2cdbc1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule( std::queue worklist; for (HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { + if (!ids_in_schedule.contains(instruction->unique_id())) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { worklist.push(instruction); @@ -204,7 +204,7 @@ Status HloSchedule::Update() { std::vector nonfusion_computations = module_->MakeNonfusionComputations(); for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " not in HloSchedule."; } if (sequences_.size() > nonfusion_computations.size()) { @@ -215,7 +215,7 @@ Status HloSchedule::Update() { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { - if (nonfusion_computations_ids.count(it->first) == 0) { + if (!nonfusion_computations_ids.contains(it->first)) { sequences_.erase(it++); } else { ++it; @@ -244,7 +244,7 @@ Status HloSchedule::Verify() const { << "Schedule has " << sequences_.size() << " sequences, but module has " << nonfusion_computations.size() << " non-fusion computations"; for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " missing from HLO schedule."; } @@ -268,7 +268,7 @@ Status HloSchedule::Verify() const { << instruction_position.size() << " instructions, expected " << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) + TF_RET_CHECK(instruction_position.contains(instruction)) << "Instruction " << instruction->name() << " is not in schedule"; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 486ddbf499de80c634bc497158cd79ca066cc866..a5f54ae2c33259d080631061dff9ae40b41495dc 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -110,7 +110,7 @@ class HloSchedule { // Returns true if the schedule has a sequence for the given computation. bool is_computation_scheduled(const HloComputation* computation) const { - return sequences_.count(computation->unique_id()) == 1; + return sequences_.contains(computation->unique_id()); } // Updates the schedule such that it is (again) a valid schedule for the diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2..37cc146bd7a6f2aef9373bd4afd8572ffac6473c 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -30,7 +31,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) { } HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { - CHECK_EQ(1, ShapeUtil::Rank(input_shape)); + CHECK_EQ(1, input_shape.rank()); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); Array assignment(dimensions); @@ -57,7 +58,7 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { HloSharding HloSharding::Tuple(const Shape& tuple_shape, absl::Span shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); } @@ -70,7 +71,7 @@ HloSharding HloSharding::Tuple(const Shape& tuple_shape, HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; @@ -80,7 +81,7 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, HloSharding HloSharding::Single(const Shape& shape, const HloSharding& sharding) { - return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; + return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding; } string HloSharding::ToString() const { @@ -106,13 +107,12 @@ string HloSharding::ToString() const { bool HloSharding::UsesDevice(int64 device) const { if (IsTuple()) { - return std::any_of( - tuple_elements_.begin(), tuple_elements_.end(), - [&](const HloSharding& s) { return s.UsesDevice(device); }); + return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) { + return s.UsesDevice(device); + }); } const auto& devices = tile_assignment_; - return replicated_ || - std::find(devices.begin(), devices.end(), device) != devices.end(); + return replicated_ || absl::c_linear_search(devices, device); } std::map HloSharding::UsedDevices(int64* count) const { @@ -269,7 +269,7 @@ int64 HloSharding::GetUniqueDevice() const { } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } @@ -305,7 +305,7 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { Status HloSharding::ValidateNonTuple(const Shape& shape, int64 num_devices) const { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Validation shape is a tuple but sharding is not.")); } @@ -316,7 +316,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // All tile assignments must be less than the number of available cores and // unique. Status status = Status::OK(); - std::set seen_cores; + absl::flat_hash_set seen_cores; tile_assignment_.Each( [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. @@ -324,7 +324,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, if (core >= num_devices) { status = tensorflow::errors::InvalidArgument(StrCat( "core ", core, " > ", num_devices, " in tile assignment")); - } else if (seen_cores.count(core) != 0) { + } else if (seen_cores.contains(core)) { status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } @@ -340,7 +340,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, } // The tile assignment tensor must have the same rank as the input. - if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { + if (shape.rank() != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. " "sharding=", @@ -437,8 +437,8 @@ Shape HloSharding::TileShape(const Shape& shape) const { } Shape result_shape = shape; for (int64 i = 0; i < shape.dimensions_size(); ++i) { - (*result_shape.mutable_dimensions())[i] = - CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); + result_shape.set_dimensions( + i, CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i))); } return result_shape; } @@ -455,7 +455,7 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, } sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); } - if (ShapeUtil::IsTuple(*sub_shape)) { + if (sub_shape->IsTuple()) { auto begin_it = tuple_elements_.begin() + sharding_index; std::vector sub_shardings( begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 9775505f8608ced3e33abe376f4922cc6a972726..5789ae09988d2a85247c5b8c037a172b3699f3b7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -101,8 +101,8 @@ class HloSharding { if (!IsTuple()) { return replicated_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsReplicated(); }); + return absl::c_all_of( + tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); }); } // Returns true if the tile size is the same as the input size. @@ -110,8 +110,9 @@ class HloSharding { if (!IsTuple()) { return maximal_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsTileMaximal(); }); + return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { + return s.IsTileMaximal(); + }); } // Returns true if the sharding defines an operation on the given device. diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index f5061304456e04ab40448861343ef201c9450dcf..094d98bc6e54028557f6d38cd165bf34e1fb8c46 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -99,7 +99,7 @@ std::vector LocatePassThroughDomainLinks( << "Instruction is not a kDomain: " << instruction->ToString(); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) != 0) { + domain.exit_domains.contains(user)) { pass_through.emplace_back(user, instruction); VLOG(2) << "Found passthrough domain link:"; VLOG(2) << " " << user->ToString(); @@ -234,7 +234,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, if (instruction->users().empty()) { // No sharding from users, use domain_sharding, after checking // compatibility. - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + TF_RET_CHECK(instruction->shape().IsTuple() && ShapeUtil::GetLeafCount(instruction->shape()) == domain_sharding.tuple_elements().size()); instruction->set_sharding(domain_sharding); @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) > 0) { + domain.exit_domains.contains(user)) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. @@ -266,7 +266,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, AssignmentKind sub_assigned = AssignmentKind::kUnassigned; TF_ASSIGN_OR_RETURN(ShapeTree user_sharding_tree, GetShardingTreeFromUser(*instruction, *user)); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // For tuple-shaped instructions collect individual tuple subshardings // from the uses, and then combine them into the tuple sharding. // If the user is a GTE its sharding concerns only the subtree of @@ -298,7 +298,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, } if (assigned == AssignmentKind::kAssigned) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { instruction->set_sharding(HloSharding::Tuple(sharding_tree)); } else { TF_RET_CHECK(sharding_tree.leaf_count() == 1); @@ -361,7 +361,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, // kUnassignedDevice. Indeed in case of doubt it is better to leave the // entire tuple unassigned, and let the device placer decide for it. if (instruction->sharding().UsesDevice(kUnassignedDevice)) { - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + TF_RET_CHECK(instruction->shape().IsTuple()) << "Only tuples can have kUnassignedDevice sub shardings"; instruction->clear_sharding(); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 487653344976a10e18ba667085525ba1ecbb8612..c1f69db74eafb7743e85f499f2f4828ed0375501 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -61,8 +61,7 @@ void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; auto pred = [&](char c) { - return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != - chars_to_replace.end(); + return absl::c_linear_search(chars_to_replace, c); }; std::replace_if(name->begin(), name->end(), pred, '_'); } @@ -159,7 +158,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, // Set the layout. if (LayoutUtil::HasLayout(instruction->shape())) { string layout_string; - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // For tuples, emit the full shape because the layout of a tuple is not // represented in a single Layout field. layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 59594ab2f0f70a206c73e998dbfa69c2c5c7ba43..218b33b2ac2b86edc30b2f014ba206c71da37682 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -46,7 +46,7 @@ const Shape& HloPosition::shape() const { string HloPosition::ToString() const { string index_str = - ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + instruction->shape().IsTuple() ? (" " + index.ToString()) : ""; return StrCat(instruction->name(), index_str); } @@ -56,10 +56,9 @@ std::ostream& operator<<(std::ostream& out, const HloPosition& position) { } string HloUse::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) - ? (" " + operand_index.ToString()) - : ""; + string index_str = instruction->operand(operand_number)->shape().IsTuple() + ? (" " + operand_index.ToString()) + : ""; return StrCat(instruction->name(), ", operand ", operand_number, index_str); } @@ -88,7 +87,7 @@ bool HloValue::operator!=(const HloValue& other) const { } string HloValue::ToShortString() const { - string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + string index_str = defining_instruction()->shape().IsTuple() ? defining_index().ToString() : ""; return StrCat(id(), " ", is_phi_ ? "PHI " : "", @@ -210,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) { } void HloValueSet::SortAndUniquifyValues() { - std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); + absl::c_sort(values_, HloValue::IdLessThan); values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), values_.end()); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e1c737132f72948e0e46d37dd08ddf8e7b29bfca..144c01eac1c06bb067c9f29f29b536c459ea273e 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -349,7 +349,10 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { Status ShapeVerifier::HandleIota(HloInstruction* instruction) { TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); - const int64 rank = ShapeUtil::Rank(iota->shape()); + if (!iota->shape().IsArray()) { + return InternalError("Iota does not support non-array result."); + } + const int64 rank = iota->shape().rank(); if (rank == 0) { return InternalError("Iota does not support scalars."); } @@ -387,6 +390,14 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); + // Bitcasts are not allowed to change the element type. + if (bitcast->operand(0)->shape().element_type() != + bitcast->shape().element_type()) { + return InternalError( + "Bitcast can not change the element type from %s to %s", + PrimitiveType_Name(bitcast->operand(0)->shape().element_type()), + PrimitiveType_Name(bitcast->shape().element_type())); + } return Status::OK(); } @@ -397,13 +408,11 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); + TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); + for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + TF_RET_CHECK((output_dimension < broadcast->shape().rank()) && output_dimension >= 0 && (broadcast->shape().dimensions(output_dimension) == operand_shape.dimensions(operand_dimension))) @@ -498,21 +507,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); + return CheckShape( + dynamic_slice, + ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); } Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); + return CheckShape( + dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); } Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { @@ -524,8 +535,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { int64 max_operand_rank = 0; for (const HloInstruction* operand : map->operands()) { operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. @@ -695,7 +705,6 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: - case HloOpcode::kSelect: case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -983,7 +992,7 @@ bool ShapeContainsToken(const Shape& shape) { bool contains_token = false; ShapeUtil::ForEachSubshape( shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { + if (subshape.IsToken()) { contains_token = true; } }); @@ -1271,11 +1280,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(broadcast->operand(0)->shape())) + broadcast->operand(0)->shape().rank()) << "Broadcast HLO (" << broadcast->ToShortString() << ") has invalid number of dimensions: " << broadcast->dimensions().size() - << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + << " != " << broadcast->operand(0)->shape().rank(); return Status::OK(); } @@ -1325,7 +1334,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } Status HandleGetTupleElement(HloInstruction* gte) override { - TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); + TF_RET_CHECK(gte->operand(0)->shape().IsTuple()); return Status::OK(); } @@ -1376,7 +1385,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); if (LayoutUtil::IsDenseArray(operand_shape) && - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + operand_shape.rank() == result_shape.rank()) { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index a1a6aba9728c137d17487b5914f67cb3966fc12b..479905b317d5639ff2cebc4d1044e21b527693f6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -168,8 +168,13 @@ class ShapeVerifier : public DfsHloVisitor { // An interface used to encapsulate target-specific verification quirks. class TargetVerifierMetadata { public: + TargetVerifierMetadata(std::function shape_size_function) + : shape_size_function_(shape_size_function) {} + // Returns a target-specific shape size. - virtual int64 ShapeSize(const Shape& shape) const = 0; + int64 ShapeSize(const Shape& shape) const { + return shape_size_function_(shape); + } virtual std::unique_ptr GetVerifier() const = 0; @@ -178,20 +183,23 @@ class TargetVerifierMetadata { TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; + + private: + // Returns a target-specific shape size. + std::function shape_size_function_; }; // The default implementation of TargetVerifierMetadata, used unless the target // needs to override it. class DefaultVerifierMetadata : public TargetVerifierMetadata { public: - DefaultVerifierMetadata(bool layout_sensitive, bool allow_mixed_precision) - : layout_sensitive_(layout_sensitive), + DefaultVerifierMetadata( + bool layout_sensitive, bool allow_mixed_precision, + std::function shape_size_function) + : TargetVerifierMetadata(shape_size_function), + layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} - int64 ShapeSize(const Shape& shape) const override { - return ShapeUtil::ByteSizeOf(shape); - } - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This creates a new verifier every time because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object for each run of @@ -210,11 +218,14 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata { // the module. class HloVerifier : public HloModulePass { public: - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, - std::function - instruction_can_change_layout_func = {}) + explicit HloVerifier( + bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}, + std::function shape_size_func = + [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) : target_metadata_(absl::make_unique( - layout_sensitive, allow_mixed_precision)), + layout_sensitive, allow_mixed_precision, shape_size_func)), instruction_can_change_layout_func_( std::move(instruction_can_change_layout_func)) { CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4bc557e4e62e7df4e25fda86fe417e84129b464c..4f69bd155b8713041ba539098808125956e86259 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -386,6 +388,55 @@ TEST_F(HloVerifierTest, AddWithLayoutChange) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258] parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicUpdateSlice_module + + ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); auto status = verifier().Run(module.get()).status(); @@ -399,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { HloModule SliceWithLayoutChange ENTRY SliceWithLayoutChange { par0 = f32[4,5]{0,1} parameter(0) - par1 = s32[2] parameter(1) - ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + par1 = s32[] parameter(1) + par2 = s32[] parameter(2) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2), dynamic_slice_sizes={3,4} } )"; @@ -429,5 +481,76 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { EXPECT_THAT(status.error_message(), HasSubstr("Instruction shouldn't change layouts")); } + +TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY BitcastCanNotChangeElementType { + constant.0 = f32[2] constant({0.0, 0.0}) + ROOT bitcast = s32[2] bitcast(constant.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Bitcast can not change the element type")); +} + +TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionNotAllowed { + p0 = pred[] parameter(0) + p1 = f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Seen floating point types of different precisions")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionAllowed { + p0 = pred[] parameter(0) + p1 = f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, IotaNonArrayResult) { + const char* const hlo_string = R"( + HloModule IotaTupleResult + + ENTRY kernelEntry { + ROOT iota = () iota(), iota_dimension=24 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("does not support non-array result")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 90904ac00110457bcc3b8974816a7080c4ab89fc..88fc62bd1e2a7830b3f61738a8642308ef4225a7 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const { // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); - std::sort( - sorted_ops.begin(), sorted_ops.end(), - [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); + absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) { + return a.cycles > b.cycles; + }); for (const auto& op : sorted_ops) { print_op(op); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 1ebb3319779c00fd4afe90606bf336e16349429d..76bf48870d55e82497ba5f63e9e2e2a322cb330e 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -103,7 +103,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( do { const HloInstruction* instr = stack.back(); - if (cache_.count(instr)) { + if (cache_.contains(instr)) { stack.pop_back(); continue; } @@ -111,9 +111,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( switch (FindOrDie(dfs_state_map, instr)) { case kDiscovered: { for (const HloInstruction* operand : instr->operands()) { - if (!cache_.count(operand)) { + if (!cache_.contains(operand)) { stack.push_back(operand); - CHECK(!dfs_state_map.count(operand) || + CHECK(!dfs_state_map.contains(operand) || dfs_state_map[operand] == kDiscovered); dfs_state_map[operand] = kDiscovered; } @@ -1002,7 +1002,7 @@ bool CanFoldDotIntoIndexedArray( absl::Span contracting_dims, absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = - GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; @@ -1015,7 +1015,7 @@ bool CanFoldDotIntoIndexedArray( return false; } - int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + int64 indexed_array_rank = indexed_array->shape().rank(); if (indexed_array->source_dim() < (indexed_array_rank - 2)) { // This restriction can be lifted by inserting reshape nodes. VLOG(3) << tag @@ -1043,7 +1043,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( return nullptr; } - int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + int64 lhs_rank = lhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); @@ -1078,7 +1078,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( return nullptr; } - int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + int64 rhs_rank = rhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_rhs_contracting_dimensions( diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 295465c8481bcb7d1385192febe0d09614e393b3..62107b5a88d4e37552fa5a6384700a9291a9c655 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/strings/ascii.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -43,7 +42,7 @@ class IndexedArrayAnalysisTest : public HloTestBase { string result; for (char c : text) { - if (!isspace(c)) { + if (!absl::ascii_isspace(c)) { result.push_back(c); } else if (!result.empty() && result.back() != ' ') { result.push_back(' '); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 07448715293ca8dde5492a054b84c3408004bdaf..b97060535d998e174639dceca5cde517cef01e30 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -174,23 +174,22 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { ShapeUtil::ForEachSubshape( hlo->shape(), [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); } }); - return std::count_if(hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kIota) { - return false; - } - if (operand->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(operand->shape())) { - return false; - } - return ShapeUtil::TrueRank(operand->shape()) >= - output_rank; - }) <= 1; + return absl::c_count_if( + hlo->operands(), [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= output_rank; + }) <= 1; } bool InstructionFusion::CanFuseOnAllPaths( @@ -274,7 +273,7 @@ InstructionFusion::ComputeGloballyUnfusible( ShapeUtil::ForEachSubshape( shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { size += ShapeUtil::ElementsIn(subshape); } }); @@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue { } sorted_operand_numbers.push_back(i); } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { + absl::c_sort( + sorted_operand_numbers, [&](int64 i, int64 j) { // Instructions with higher priority in the queue come first. return ( FindOrDie(post_order_index_, instruction->mutable_operand(i)) > diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index a981d94a999e3d322986bc2bfd56a5b0b5d175fc..a305c6e8005045f7dbca3b8099a3b8ddebb092af 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,12 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + cc_library( name = "interpreter_transfer_manager", srcs = ["interpreter_transfer_manager.cc"], @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -47,8 +48,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", @@ -115,6 +118,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/stream_executor/host:host_stream", + "//tensorflow/stream_executor/host:host_timer", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index d37ae94bf6c4c697bbf30390c02a5099271e00a4..0827b1daf89bebb68c045784ef2b9da677792880 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -31,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -40,12 +43,50 @@ limitations under the License. namespace xla { namespace interpreter { +namespace { + +// Handles custom_call ops during evaluation by routing them through the global +// CPU registry used by other CPU-based backends. +StatusOr HandleEvaluatorCustomCall( + HloInstruction* custom_call, absl::Span operands) { + // Find the target C function in the global registry. + auto* registry = xla::cpu::CustomCallTargetRegistry::Global(); + void* target_fn = registry->Lookup(custom_call->custom_call_target()); + if (!target_fn) { + return NotFound("Custom call target '%s' was not registered", + custom_call->custom_call_target()); + } + + // Populate pointers to operand and output literal data. + std::vector operand_data; + operand_data.reserve(operands.size()); + for (const auto* literal : operands) { + operand_data.push_back(literal->untyped_data()); + } + auto output = Literal::CreateFromShape(custom_call->shape()); + void* output_data = output.untyped_data(); + + // Call the target function matching the C ABI used by the CPU backends. + auto* typed_fn = reinterpret_cast(target_fn); + (*typed_fn)(output_data, operand_data.data()); + + return std::move(output); +} + +} // namespace + Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); + + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + return pipeline.Run(hlo_module).status(); } @@ -75,10 +116,12 @@ StatusOr> InterpreterCompiler::RunBackend( // In this case we are using an HloEvaluator at execution time, so we don't // need to compile anything - // Create executable from only the Hlo module. auto evaluator = absl::make_unique(); evaluator->set_use_fast_path( hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); + evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); + + // Create executable from only the Hlo module. std::unique_ptr executable = absl::make_unique(std::move(hlo_module), std::move(evaluator)); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index de9204011ce5ba8a9fc2871c6bd7120b6ed371b5..7a6ebdef708bcc3a92fbd8618db0c42c35e6ce8b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -68,6 +68,18 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } + // Check that the args have the right shape. + for (int64 i = 0; i < computation->num_parameters(); ++i) { + const auto& expected_shape = computation->parameter_instruction(i)->shape(); + const auto& actual_shape = arguments[i]->on_device_shape(); + if (!ShapeUtil::Equal(expected_shape, actual_shape)) { + return InvalidArgument( + "Shape mismatch on parameter %d. Expected %s, but was %s.", i, + ShapeUtil::HumanString(expected_shape), + ShapeUtil::HumanString(actual_shape)); + } + } + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, TransferManager::GetForPlatform(platform)); @@ -86,8 +98,8 @@ StatusOr InterpreterExecutable::ExecuteOnStream( { tensorflow::mutex_lock lock(evaluator_lock_); evaluator_->ResetVisitStates(); - TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate(*computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. @@ -117,7 +129,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( } /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index b9ddd9636fe29e85092ed67fc644a54332b218d3..5d9e3392fd86c587a0bd998a282c52d145cc710e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded( PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); PointsToSet::BufferSet* operand_buffers = GetBufferSet(instruction->operand(operand_no)); - for (const LogicalBuffer* output_buffer : *output_buffers) { - if (operand_buffers->count(output_buffer) > 0) { - return true; - } - } - return false; + return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) { + return operand_buffers->count(b) > 0; + }); } Status LayoutConstraints::SetBufferLayout(const Layout& layout, @@ -256,7 +253,7 @@ Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + TF_RET_CHECK(operand->shape().IsArray()); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); @@ -314,7 +311,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(1, buffers.size()); CHECK_EQ(buffers[0]->instruction(), instruction); - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); @@ -406,7 +403,7 @@ Status LayoutAssignment::BuildHostChannelConstraints( instruction->opcode() == HloOpcode::kRecv) { const Shape& data_shape = ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(data_shape)); + TF_RET_CHECK(data_shape.IsArray()); TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( send_recv_instr->channel_id(), data_shape.layout()); @@ -489,7 +486,7 @@ Status LayoutAssignment::AddMandatoryConstraints( if (instruction->opcode() == HloOpcode::kSend) { // TODO(b/68493863): Change to use SetOperandLayout(). const Shape send_buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); + TF_RET_CHECK(send_buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(send_buffer_shape, @@ -499,7 +496,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } else { const Shape recv_buffer_shape = ShapeUtil::GetTupleElementShape(instruction->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape)); + TF_RET_CHECK(recv_buffer_shape.IsArray()); TF_ASSIGN_OR_RETURN( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, @@ -520,7 +517,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } // TODO(b/68493863): Change to use SetOperandLayout(). const Shape& buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + TF_RET_CHECK(buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(buffer_shape, all_reduce_id); @@ -780,7 +777,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << ShapeUtil::HumanString(instruction->shape()) << " instruction: " << instruction->ToString(); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // Copy tuple elements which have differing layouts. std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); @@ -811,7 +808,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( shape_with_layout, tuple_copy->mutable_shape())); return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { + } else if (instruction->shape().IsArray()) { HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); @@ -988,11 +985,10 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape())); - CHECK(ShapeUtil::IsArray(operand->shape())); + CHECK(instruction->shape().IsArray()); + CHECK(operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape()) && + operand->shape().rank() == instruction->shape().rank() && !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. @@ -1012,7 +1008,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - ShapeUtil::Rank(instruction->shape()) == 1) { + instruction->shape().rank() == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; } @@ -1026,7 +1022,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return absl::make_unique(operand_shape.layout()); } - if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + if (operand_shape.rank() == output_shape.rank()) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { @@ -1045,7 +1041,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(instruction->shape()); + int64 rank = instruction->shape().rank(); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { int64 output_dim = LayoutUtil::Minor(output_layout, i); @@ -1066,11 +1062,10 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( int64 operand_no) { const HloInstruction* operand = user->operand(operand_no); - CHECK(ShapeUtil::IsArray(user->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(user->shape().IsArray() && operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + operand->shape().rank() == user->shape().rank() && !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); @@ -1083,7 +1078,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. - if (ShapeUtil::Rank(operand->shape()) == 1 && + if (operand->shape().rank() == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; @@ -1098,7 +1093,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return absl::make_unique(output_shape.layout()); } - if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + if (operand->shape().rank() == output_shape.rank()) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { @@ -1117,7 +1112,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kTranspose) { // Pick the user layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(user->shape()); + int64 rank = user->shape().rank(); std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { @@ -1193,7 +1188,7 @@ std::vector> GetArrayUsesOfBuffer( CHECK(buffer.IsArray()); std::vector> uses; for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) { - if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) { + if (!buffer_alias.instruction()->shape().IsArray()) { continue; } // This alias must be the top-level (index == {}) of the instruction's @@ -1227,7 +1222,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { for (const LogicalBuffer* buffer : buffers) { if (constraints->BufferLayout(*buffer) == nullptr && - ShapeUtil::IsArray(buffer->shape())) { + buffer->shape().IsArray()) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), *buffer, /*mandatory=*/true)); @@ -1238,6 +1233,23 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( }); } +namespace { +// A transpose or a reshape that only changes trivial dimensions have meaningful +// layouts that are valuable to propagate in a depthfirst manner to avoid +// unassigned layouts in the graph. +bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { + switch (hlo.opcode()) { + case HloOpcode::kReshape: + return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + +} // namespace + Status LayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& operand_constraint, LayoutConstraints* constraints) { @@ -1258,11 +1270,10 @@ Status LayoutAssignment::PropagateOperandConstraint( // layout for the operands with the same ranks. const HloInstruction* operand = operand_constraint.operand(); const HloInstruction* user = operand_constraint.instruction(); - if (!ShapeUtil::IsArray(operand->shape())) { + if (!operand->shape().IsArray()) { return Status::OK(); } - if (instruction_can_change_layout_func_(user) && - !ShapeUtil::IsArray(user->shape())) { + if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) { return Status::OK(); } @@ -1273,7 +1284,7 @@ Status LayoutAssignment::PropagateOperandConstraint( return Status::OK(); } - int64 operand_rank = ShapeUtil::Rank(operand->shape()); + int64 operand_rank = operand->shape().rank(); if (operand_rank <= 1) { return Status::OK(); } @@ -1288,7 +1299,7 @@ Status LayoutAssignment::PropagateOperandConstraint( continue; } const HloInstruction* sibling = user->operand(operand_no); - const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + const int64 sibling_rank = sibling->shape().rank(); if (sibling_rank <= 1) { continue; } @@ -1317,16 +1328,16 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } // Assign the right layout to input fusion of higher rank reduce // operations. - if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + if (subshape.rank() != operand->shape().rank()) { return Status::OK(); } // TODO(b/67641796): Are there cases except fusion that use this code @@ -1354,10 +1365,10 @@ Status LayoutAssignment::PropagateOperandConstraint( } TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } TF_ASSIGN_OR_RETURN( @@ -1373,7 +1384,7 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(constraints->SetBufferLayout( *layout, *buffer, /*mandatory=*/user->opcode() == HloOpcode::kReduce, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*user))); } } return Status::OK(); @@ -1401,8 +1412,8 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } if (!instruction_can_change_layout_func_(instruction)) { // Copy the layout to the operand. - if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == + if (buffer.IsArray() && operand->shape().IsArray() && + operand->shape().rank() == LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( buffer_constraint.layout(), instruction, operand_no, @@ -1410,7 +1421,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } } else { if (!buffer.IsTopLevel() || - !ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + !instruction->operand(operand_no)->shape().IsArray()) { continue; // Don't touch buffers that are internal to a tuple. } VLOG(6) << "Propagating constraint to operand " << operand_no << " of " @@ -1423,11 +1434,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { - // Do not propagate operand constraints of transposes and reshapes, it - // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( *operand_layout, instruction, operand_no, /*mandatory=*/false, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction))); } } else { VLOG(6) << "Operand already has a constraint " @@ -1497,7 +1506,7 @@ StatusOr InferArrayLayout( // This function should only be called for array shapes which don't yet have // layouts. const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index); - TF_RET_CHECK(ShapeUtil::IsArray(subshape)); + TF_RET_CHECK(subshape.IsArray()); TF_RET_CHECK(!subshape.has_layout()); // The instruction should not define the buffer at this index. @@ -1576,8 +1585,9 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { // Nop; leave the infeed layout alone. - } else { + } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) { // Other instructions don't have layouts inside of fusion nodes. + // But do not clear layouts for other instructions in custom fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1615,7 +1625,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, for (const LogicalBuffer* buffer : constraints.points_to_analysis().GetBuffersDefinedByInstruction( instruction)) { - if (!ShapeUtil::IsArray(buffer->shape())) { + if (!buffer->shape().IsArray()) { continue; } @@ -1639,7 +1649,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( instruction->mutable_shape(), [instruction, &constraints](Shape* subshape, const ShapeIndex& index) { - if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) { + if (subshape->has_layout() || !subshape->IsArray()) { return Status::OK(); } // Set Layout of subshape to match layout of LogicalBuffer which @@ -2100,8 +2110,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( /* static */ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { - return ShapeUtil::Rank(shape) <= 1; + if (shape.IsArray()) { + return shape.rank() <= 1; } return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { return IsAtMostRank1(subshape); @@ -2123,7 +2133,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kCopy && - added_copies_.count(instruction) > 0) { + added_copies_.contains(instruction)) { VLOG(5) << "Removing added copy: " << instruction->ToString(); TF_RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 3b081de3c7826c3c11a7d87d542835d0ecce1b7e..5701cb5b025e563247d46d0d24f81a5f886fc23b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -243,7 +243,7 @@ class ChannelLayoutConstraints { // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64 channel_id) const { - return constraints_.count(channel_id) > 0; + return constraints_.contains(channel_id); } // Given `shape`, apply the layout for `channel_id`. `channel_id` must already @@ -276,7 +276,7 @@ class ChannelLayoutConstraints { } private: - std::unordered_map constraints_; + absl::flat_hash_map constraints_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 31d78752f07c57aef6023fabb8e3a7de20c4278c..c8cf3c47d380012fdb0206c0d20d67e6a13017ae 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -528,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - if (ShapeUtil::Rank(instruction->shape()) != - ShapeUtil::Rank(operand->shape())) { + if (instruction->shape().rank() != operand->shape().rank()) { continue; } TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( @@ -961,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { par0 = f32[3,4]{1,0} parameter(0) par1 = f32[4,5]{0,1} parameter(1) - par2 = s32[2] parameter(2) - dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + par2 = s32[] parameter(2) + par3 = s32[] parameter(3) + dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4} ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) } )"; @@ -983,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { m::Parameter(), m::DynamicSlice( m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), - m::Parameter(2))))); + m::Parameter(2), m::Parameter(3))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 728a66b388f0f9af480ff88b5e96990a26e36af5..c5d59fb28e02ce229967fb3856012d608fb83c5d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -39,7 +39,6 @@ cc_library( "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], @@ -169,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 643ecd0fbaa546c551097b29e74ccd49418e1466..ce3d922ca7a9bdea3a520959a8b8d284bc3e0d64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (hlo.opcode() == HloOpcode::kParameter) { const std::vector& parameter_instructions = module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), - parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { + if (absl::c_linear_search(parameter_instructions, &hlo)) { array->MarkInvariantOverWholeProgram(context_); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 2b46b3c3964b15548dbacc8b0ada0047a0fa85b6..12e2f449e23ac2511aac576fed893f5a9ef510c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -76,15 +76,12 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - absl::flat_hash_map + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - absl::flat_hash_map - noalias_metadata_; + absl::flat_hash_map noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index bdce4a171b8a58f617f1d56e6cf6db5354846703..c2c6405cdad28196a4793887c8c5cc5b87ee5301 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,10 @@ string ConstantBufferAllocationToGlobalName( const BufferAllocation& allocation) { string instr_name = InstrForConstantBufferAllocation(allocation).name(); for (char& c : instr_name) { - if (c == '.') { + // Having a hyphen in a global variable name can crash the LLVM PTX backend. + // LLVM is able to generate unique global variable names using the string + // returned from here as name prefix. + if (c == '.' || c == '-') { c = '_'; } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 4d7f36d9f8b565a819edf0631efc5c7a58c4f87f..c66eaec8fb0e4c03f6967fec0cf0ae9661cdf470 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -36,19 +36,20 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. +using IndexGenerator = std::function(int64)>; + static Status EmitDynamicUpdateSliceInPlaceImpl( - const Shape& update_shape, const ElementGenerator& start_indices_generator, + const Shape& update_shape, const IndexGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. - const int64 rank = ShapeUtil::Rank(output_shape); + const int64 rank = output_shape.rank(); IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({b->getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( @@ -112,9 +113,20 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, b); - }; + IndexGenerator start_indices_generator; + // TODO(b/118437727): Remove the R1 path, and rename the variables. + if (start_indices_array.GetShape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return start_indices_array.EmitReadArrayElement( + IrArray::Index({b->getInt64(index)}), b); + }; + } else { + start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); + }; + } + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); }; @@ -165,8 +177,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - ElementGenerator start_indices_generator = - fused_emitter.GetGenerator(start_indices); + + // TODO(b/118437727): Remove the R1 path, and rename the variables. + IndexGenerator start_indices_generator; + if (start_indices->shape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return fused_emitter.GetGenerator(start_indices)( + IrArray::Index({b->getInt64(index)})); + }; + } else { + start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; + } bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 38f2b5da23a7b92e4547dceaba011ce654977da3..e440f05e2b2f0d4a2a4c7b326b4881183de4d235 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -35,7 +35,7 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (generated_value_cache_[hlo].count(index.multidim()) > 0) { + if (generated_value_cache_[hlo].contains(index.multidim())) { llvm::Value* generated_value = generated_value_cache_[hlo][index.multidim()]; llvm::BasicBlock* generated_value_bb = nullptr; @@ -115,7 +115,7 @@ Status FusedIrEmitter::HandleGetTupleElement( /*alignment=*/1, tuple_ptr, b_, module_); }; - if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { + if (!get_tuple_element->shape().IsTuple()) { indexed_generators_[get_tuple_element] = [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 1b9c61f6700e2a1309b21e499f4a9e2439ed3702..e6d52a580c04a920d3f0e8ed6f39c1cae587cf1b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - std::unordered_map, llvm::Value*>> + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map, llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 67f7423121177e2ca1e3384341dad2644c8f5e34..8ee07ae8331e986f9d271be5e39065f0d87853b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(ShapeUtil::Rank(shape)), + : multidim_(shape.rank()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -104,8 +104,8 @@ IrArray::Index::Index(absl::Span multidim, CHECK(LayoutUtil::HasLayout(shape)); } -IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) - : base_ptr_(base_ptr), shape_(&shape) { +IrArray::IrArray(llvm::Value* base_ptr, Shape shape) + : base_ptr_(base_ptr), shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); CHECK(base_ptr_->getType()->isPointerTy()); int depth = 0; @@ -117,10 +117,10 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { + if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); } } @@ -137,12 +137,12 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { const auto& target_index = *this; - CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); + CHECK_EQ(target_index.size(), output_shape.rank()); std::vector> common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); + input_shape.rank(), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { @@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { - int64 rank = ShapeUtil::Rank(operand_shape); + int64 rank = operand_shape.rank(); std::vector source_index(rank); for (int64 i = 0; i < rank; ++i) { source_index[i] = multidim_[dimension_mapping[i]]; @@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( // The other dimensions can be masked out with a div and a mod operation. std::vector logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); - int64 output_rank = ShapeUtil::Rank(shape); + int64 output_rank = shape.rank(); // The minimum physical dimension that is broadcasted. int64 min_broadcasted_dimension = output_rank; // The maximum physical dimension that is broadcasted. @@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // over higher-rank arrays. return base_ptr_; } - CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); + CHECK_EQ(index.size(), shape_->rank()); if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b..b706ebd311cbb706e7e4698b93319e37e664d10a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -130,6 +130,11 @@ class IrArray { CHECK_LE(index, size()); mutable_multidim().insert(mutable_multidim().begin() + index, value); } + void InsertAt(int64 index, int64 count, llvm::Value* value) { + CHECK_LE(index, size()); + mutable_multidim().insert(mutable_multidim().begin() + index, count, + value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; @@ -189,6 +194,8 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } + void ClearLinearIndex() { linear_ = nullptr; } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& mutable_multidim() { @@ -220,11 +227,11 @@ class IrArray { }; // Default constructor. Constructs an IrArray in a null status. - IrArray() : base_ptr_(nullptr), shape_(nullptr) {} + IrArray() : base_ptr_(nullptr) {} // Construct an IrArray with the given base pointer and shape. base_ptr is a // pointer type pointing to the first element(lowest address) of the array. - IrArray(llvm::Value* base_ptr, const Shape& shape); + IrArray(llvm::Value* base_ptr, Shape shape); // Default implementations of copying and moving. IrArray(IrArray&& other) = default; @@ -236,7 +243,6 @@ class IrArray { llvm::Type* GetElementLlvmType() const { return element_type_; } const Shape& GetShape() const { - CHECK(shape_ != nullptr); return *shape_; } @@ -331,7 +337,7 @@ class IrArray { llvm::Type* element_type_; // Shape of the XLA array. - const Shape* shape_; + absl::optional shape_; // The list of key/value pairs used when attaching metadata to emitted // loads/stores for this array. They keys are the metadata kinds and the diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf..cf5083e8c13b9485035923895cec1ad05049c644 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -254,6 +254,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpOLT(std::forward(args)...); } + template + llvm::Value* FCmpOLE(Args&&... args) { + return mixin_builder()->CreateFCmpOLE(std::forward(args)...); + } + template llvm::Value* FCmpONE(Args&&... args) { return mixin_builder()->CreateFCmpONE(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cebbc4290163d4e98003cd7b5df6ec906509a446..cd8dd72cd775d5e0b52f96a2326367da0775e7eb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -123,7 +123,8 @@ KernelMappingScheme::KernelMappingScheme( dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, tile_size_x}, num_threads_x_(num_threads_x), - num_threads_y_(num_threads_y) { + num_threads_y_(num_threads_y), + dilated_x_(true) { DCHECK_EQ(dims_in_elems_.size(), 3); DCHECK_EQ(req_block_sizes.size(), 3); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index fb633b12e60d1a9f3103fb2919ad2c3f3f14de20..f802cc27d519e621262f328903697373aa8c284c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -117,7 +117,10 @@ class KernelMappingScheme { int64 GetNumberOfTilesInOneBlock() const { return absl::c_accumulate(block_sizes_, 1, std::multiplies()); } - + int64 GetNumberOfTilesInOneBlockForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return block_sizes_[d]; + } int64 GetNumberOfBlocks() const { return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); } @@ -147,6 +150,16 @@ class KernelMappingScheme { GetNumberOfThreadsForDimensionY(); } + bool DilatedX() const { return dilated_x_; } + void SetDilatedX(bool v) { + dilated_x_ = v; + if (!dilated_x_) { + // dilated_x_=false is for the purpose of vectorization, which requires + // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. + CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); + } + } + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); // Returns the index for the first tile in the block with the given block // index. @@ -186,6 +199,13 @@ class KernelMappingScheme { int64 num_threads_x_; // Number of threads used to process elements in the Y direction of a tile. int64 num_threads_y_; + + // When num_threads_x threads process a total of tile_size_x elements in the + // X dimension of a tile, each threads process n=tile_size_x/num_threads_x + // elements. When dilated_x=false, the n elements processed by a thread are + // contiguous. On the other hand, when dilated_x=true the n elements are + // dilated by a factor of num_threads_x. + bool dilated_x_; }; // A class to represent information for tiled parameters to support IR emission diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 219a9f221fbd116cdfbaf17985e21d82aefd079d..fe320bbe727111fbc986cc1fbc217feed74d30f1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -235,7 +235,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { - std::vector dimensions(ShapeUtil::Rank(shape)); + std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ceea24685af566e02340664f0a40c398c62b5ab0..807296329c07b8e4ac630486a1e1f59e4fdfa009 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } return cplx_t; } - // A Tuple contains an array of pointers. Use i8*. + case C128: { + auto cplx_t = module->getTypeByName("complex128"); + if (cplx_t == nullptr) { + return llvm::StructType::create( + {llvm::Type::getDoubleTy(module->getContext()), + llvm::Type::getDoubleTy(module->getContext())}, + "complex128", /*isPacked=*/true); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: @@ -219,10 +228,10 @@ int GetSizeInBits(llvm::Type* type) { llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); @@ -621,6 +630,10 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, function->setCallingConv(llvm::CallingConv::C); function->addFnAttr("no-frame-pointer-elim", "false"); + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. + function->setHasUWTable(); + if (enable_fast_math) { function->addFnAttr("unsafe-fp-math", "true"); function->addFnAttr("no-infs-fp-math", "true"); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 6a9406bfebafcc02dc2e144b62284a9e83c3edeb..89b6a36f96beedbcb7322e6164ac59221650d3d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -322,7 +322,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // comparisons). const Shape& keys_shape = keys_array.GetShape(); - int64 rank = ShapeUtil::Rank(keys_shape); + int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); std::vector iteration_order_to_logical_order(rank); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index a60643bc754f896d096b3ca4e1216e77d7e384c6..d8d2700e1934fd202d44a1dc60e71a99913d4537 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -93,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. - if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { SetDereferenceableMetadataForLoad( src_buffer, ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 600b069ecdbabf6b05e6abb3a6b8d9b1a4b0ecf4..3470fe5b2c34bf832207ed546fad176319446f31 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -110,6 +110,7 @@ ExecutionOptions CreateExecutionOptions( *execution_options.mutable_shape_with_output_layout() = result_shape.ToProto(); } + execution_options.set_num_replicas(build_options.num_replicas()); return execution_options; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448..53d52d9a3d918fa6dee093668923fcfff963d084 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { continue; } - if (in_list.count(instr) > 0) { + if (in_list.contains(instr)) { continue; } int64 profit = GetProfit(instr, fusion); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index daa718879ddd45afb02725b557380b2f49fe833e..e55b83d17e90bc2ca0053a0421cf80ef6edd5bca 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -28,13 +29,13 @@ namespace { bool IsAllowed(char character) { auto c = static_cast(character); - return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; + return (absl::ascii_isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; } } // namespace NameUniquer::NameUniquer(const string& separator) { - CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + CHECK(absl::c_all_of(separator, IsAllowed)) << "separator should comprises allowed characters only"; separator_ = separator; } @@ -46,7 +47,7 @@ NameUniquer::NameUniquer(const string& separator) { string result = name; char c = static_cast(result[0]); - if (!isalpha(c) && c != '_') { + if (!absl::ascii_isalpha(c) && c != '_') { result[0] = '_'; } for (int i = 1; i < result.length(); i++) { diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index fdb6a9b01be4b9198e40aa9bf7cdc07ff068a619..9e3d1060210790f60243195a1c1dff13f1fc7fc5 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -775,7 +775,7 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (!ShapeUtil::IsArray(*shape)) { + if (!shape->IsArray()) { EXPLAIN << "Shape is not an array"; return false; } @@ -793,7 +793,7 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (!ShapeUtil::IsTuple(*shape)) { + if (!shape->IsTuple()) { EXPLAIN << "Shape is not a tuple"; return false; } @@ -831,7 +831,7 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (ShapeUtil::Rank(*shape) != rank_) { + if (shape->rank() != rank_) { if (rank_ == 0) { EXPLAIN << "Shape is not a scalar"; } else { @@ -1878,7 +1878,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().EqualTo(shape))) { return WithShape(Shape().EqualTo(shape)); } @@ -1886,7 +1886,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { return WithShape(Shape().CompatibleTo(shape)); } @@ -2057,7 +2057,6 @@ XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Slice) -XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN @@ -2119,7 +2118,6 @@ XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -2236,8 +2234,10 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(DynamicSlice) XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); +XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc index 9ca2fb05c1f7ef093c58237cf21fbc7c813a592a..f51a18b13894d75300c46835fabd82a4ce0699af 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -23,7 +23,6 @@ namespace xla { namespace { namespace m = ::xla::match; -using ::testing::Eq; using ::testing::Not; template diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 896b73cda41cb21b539b586aa4701c5bad43f8b9..886a0545624927fa77528141f61d8ecb6bec180a 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -70,6 +70,9 @@ PlatformUtil::GetSupportedPlatforms() { for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " @@ -260,8 +263,8 @@ PlatformUtil::GetStreamExecutors( // Block here in thread_pool destructor until all devices are initialized. } VLOG(1) << "Device initialization complete"; - if (std::all_of(stream_executors.begin(), stream_executors.end(), - [](se::StreamExecutor* s) { return s == nullptr; })) { + if (absl::c_all_of(stream_executors, + [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", platform->Name()); } diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 4df746fca9f8320eed72911726f33bb01f06fed5..a62118df157edf67114ff41befbdce3da129fe93 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -226,7 +226,10 @@ StatusOr PerformSinkReshapeOrTranspose( // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_operand_shape.dimensions(); + shape->clear_dimensions(); + for (int64 i : new_operand_shape.dimensions()) { + shape->add_dimensions(i); + } *shape->mutable_layout() = new_operand_shape.layout(); } } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 11c2f8392d285095816dd5d61f7029c1bfd158d4..acad871c4d427b174ffce3a462a0a3918a1e0c33 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ limitations under the License. namespace xla { - // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. static StatusOr TransposeIndexVectorDimToLast( @@ -60,6 +59,13 @@ static StatusOr CanonicalizeScatterIndices( TF_ASSIGN_OR_RETURN( HloInstruction * transposed_scatter_indices, TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() == index_vector_dim + 1 && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } bool indices_are_scalar = index_vector_dim == scatter_indices->shape().dimensions_size(); @@ -88,7 +94,7 @@ static StatusOr CanonicalizeScatterIndices( static StatusOr PermuteScatterAndWindowDims( HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; - const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + const int64 updates_rank = updates->shape().rank(); permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { @@ -165,10 +171,9 @@ static StatusOr CheckIndexValidity( // Valid range for the index: [0, operand_dims - window_sizes] // Check if the index has any negative values. - TF_ASSIGN_OR_RETURN( - HloInstruction * zero_index, + HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), - AsInt64Slice(index->shape().dimensions()))); + AsInt64Slice(index->shape().dimensions())); TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); @@ -214,15 +219,11 @@ static StatusOr> ScatterLoopBody( HloInstruction* updates = loop_state[2]; bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; - CHECK_EQ(has_scalar_indices, - dim_numbers.index_vector_dim() == - scatter->operand(1)->shape().dimensions_size()); // Build a vector form of the induction variable of the while loop. - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); // Pick the index to scatter from scatter_indices based on the induction_var // and transform that to an index into the `operand` space. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index a0126f39b3dc4281abedc36a19dd20c3b128e249..83434528a21b16cad7c831e7d9cc42d436634540 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -295,11 +296,16 @@ StatusOr> Service::CreateModuleConfig( computation_layout->mutable_result_layout()->SetToDefaultLayout(); } - config->set_replica_count(options_.number_of_replicas()); if (execution_options != nullptr) { + if (execution_options->num_replicas() > 0) { + config->set_replica_count(execution_options->num_replicas()); + } else { + config->set_replica_count(options_.number_of_replicas()); + } config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { + config->set_replica_count(options_.number_of_replicas()); config->set_debug_options(GetDebugOptionsFromFlags()); } @@ -523,13 +529,13 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile) { + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*backend, SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle)); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, @@ -537,10 +543,11 @@ StatusOr Service::ExecuteAndRegisterResult( streams.push_back(std::move(stream)); } - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), - /*computation_count=*/1)); + DeviceAssignment device_assignment(options_.number_of_replicas(), + /*computation_count=*/1); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, 0) = replicas[replica]->device_ordinal(); + } // Set up run options. std::vector run_options; @@ -552,9 +559,7 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back( - options, backend->StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); + run_options.emplace_back(options, backend->StreamBorrower()); } if (options_.number_of_replicas() == 1) { @@ -711,14 +716,33 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, } } - // Execute the generated executables in parallel and return the device - // handles for each computation's output. + // If we have multiple executables to run, execute them all in parallel. But + // if we only have one executable, execute it using the vanilla, non-parallel + // call. + // + // We do this because the Client API uses ExecuteGraphParallel when it wants + // to compile and run one computation without caching the executable, but not + // all backends support the async StreamExecutor API required by + // ExecuteParallelAndRegisterResult. + // + // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do + // basically the same thing. ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); + std::vector outputs; + if (executable_ptrs.size() == 1) { + TF_ASSIGN_OR_RETURN( + auto output, + ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0], + execute_backend_.get(), device_handles[0], + computation_names[0], &profile)); + outputs.push_back(std::move(output)); + } else { + TF_ASSIGN_OR_RETURN( + outputs, ExecuteParallelAndRegisterResult( + executable_ptrs, all_arguments, execute_backend_.get(), + device_handles, computation_names, &profile)); + } + for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; @@ -904,6 +928,7 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { *result->mutable_output(), ExecuteAndRegisterResult(executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), "result of " + executable->module().name(), result->mutable_profile())); @@ -1097,9 +1122,12 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module.get())); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( - *module, /*arg_literals=*/{})); + evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index abd3ee5a059ac0910d6acc8076899950498b4c43..fd907d07daef9e8337aeed198ef4fd23d069df21 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -53,7 +53,7 @@ class ServiceOptions { ServiceOptions& set_platform(se::Platform* platform); se::Platform* platform() const; - // Set the number of replicas to use when compiling replicated + // Set the default number of replicas to use when compiling replicated // programs. ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -250,8 +250,9 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile); + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index dbfed628bfcabffe66bef41a82e0e2430897d80d..6bee671056552b83014367889320b748659bbfdf 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -32,12 +32,10 @@ class ServiceExecutableRunOptions { ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} - explicit ServiceExecutableRunOptions( - ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) + explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, + StreamBorrower borrow_stream = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)), - xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} + borrow_stream_(std::move(borrow_stream)) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -56,15 +54,9 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } - // Returns reference to thread pool for execution of XLA ops on CPU backend. - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { - return xla_intra_op_thread_pool_; - } - private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8e571675c79b08efd454ee5e0fe47bacdcf3dbb7..946577d55d43f04fe2dbabb3dd11c3468f2c7edf 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" -#include #include +#include #include #include #include @@ -50,7 +50,7 @@ bool AllUnique(absl::Span slice) { } Status ExpectArray(const Shape& shape, absl::string_view op_type) { - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument("Expected array argument for %s, but got %s.", string(op_type), ShapeUtil::HumanString(shape)); } @@ -70,7 +70,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); std::vector accumulator_subshapes; - if (ShapeUtil::IsArray(accumulator_shape)) { + if (accumulator_shape.IsArray()) { if (inputs != 1) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but " @@ -78,7 +78,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, inputs); } accumulator_subshapes.push_back(&accumulator_shape); - } else if (ShapeUtil::IsTuple(accumulator_shape)) { + } else if (accumulator_shape.IsTuple()) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but has " @@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } for (const Shape* element_shape : accumulator_subshapes) { - if (ShapeUtil::Rank(*element_shape) != 0) { + if (element_shape->rank() != 0) { return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", @@ -156,17 +156,26 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return Status::OK(); } +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} + StatusOr InferWindowOutputShape(const Shape& base_shape, const Window& window, PrimitiveType element_type, bool allow_negative_padding) { - if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { + if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", - window.dimensions_size(), ShapeUtil::Rank(base_shape)); + window.dimensions_size(), base_shape.rank()); } std::vector output_dimensions(window.dimensions_size()); + std::vector output_is_dynamic(window.dimensions_size()); for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -196,6 +205,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } + if (base_shape.is_dynamic_dimension(i) && !IsTrivialWindowDimension(dim)) { + return Unimplemented( + "Dynamic shape is not supported for non trivial window: %s", + window_util::ToString(window)); + } + const int64 dilated_base = window_util::DilatedBound( ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); const int64 padded_dilated_base = @@ -205,9 +220,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, output_dimensions[i] = window_util::StridedBound( padded_dilated_base, dilated_window, dim.stride()); + output_is_dynamic[i] = base_shape.is_dynamic_dimension(i); } - return ShapeUtil::MakeValidatedShape(element_type, output_dimensions); + return ShapeUtil::MakeValidatedShape(element_type, output_dimensions, + output_is_dynamic); } } // namespace @@ -338,7 +355,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } - if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { + if (dimension < 0 || dimension >= arg_shapes[0]->rank()) { return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } @@ -351,12 +368,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, element_type = arg_shape->element_type(); continue; } - if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { + if (arg_shape->rank() != shape->rank()) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), - ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); + arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(), + ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( @@ -364,8 +381,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, PrimitiveType_Name(arg_shape->element_type()), PrimitiveType_Name(shape->element_type())); } - for (int64 dimension_number = 0; - dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { + for (int64 dimension_number = 0; dimension_number < arg_shape->rank(); + ++dimension_number) { if (arg_shape->dimensions(dimension_number) != shape->dimensions(dimension_number)) { if (dimension_number == dimension) { @@ -401,7 +418,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -424,7 +441,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -472,7 +489,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -480,7 +497,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Pad operation does not support non-scalar padding values."); } - if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { + if (operand_shape.rank() != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", @@ -500,35 +517,40 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, padding_config.ShortDebugString()); } - std::vector dimensions(ShapeUtil::Rank(operand_shape)); + if (!padding_value_shape.is_static()) { + return InvalidArgument("Dynamic padding value is not supported"); + } + + std::vector dimensions(operand_shape.rank()); + std::vector is_dynamic(operand_shape.rank()); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); + if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 && + p.edge_padding_low() != 0 && p.interior_padding() != 0) { + return InvalidArgument( + "Dynamic dimension on padding dimension is not supported."); + } dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * p.interior_padding(); + is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } + return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions); + dimensions, is_dynamic); } // Current DotDimensionNumbers Requirements: // // Contracting Dimensions: -// *) Exactly one contracting dimension on both lhs and rhs. +// *) Same number of contracting dimensions on both lhs and rhs. // *) Contracting dimension size must be the same on both lhs and rhs. -// *) Contracting dimension numbers do not need to be the same (i.e. transposes -// are passed on to emitter implementations). // // Batch Dimensions: // *) Same number of batch dimensions on both lhs and rhs. -// *) Same batch dimension numbers (and sizes) on both lhs and rhs. -// *) Batch dimension numbers must be ordered before contracting and -// non-contracting/non-batch dimension numbers. -// -// Non-Contracting-Non-Batch Dimensions: -// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// *) Same batch dimension sizes on both lhs and rhs. // namespace { @@ -541,9 +563,8 @@ Status ValidateDotDimensionNumbers( absl::Span contracting_dims, absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - in_range) && - std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + return absl::c_all_of(contracting_dims, in_range) && + absl::c_all_of(batch_dims, in_range); }; absl::Span lhs_contracting_dimensions = @@ -555,9 +576,9 @@ Status ValidateDotDimensionNumbers( absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); - if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions, lhs_batch_dimensions) || - !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + !dims_in_range(rhs.rank(), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString()); @@ -570,9 +591,8 @@ Status ValidateDotDimensionNumbers( auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - is_unique) && - std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + return absl::c_all_of(contracting_dims, is_unique) && + absl::c_all_of(batch_dims, is_unique); }; if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || @@ -581,36 +601,6 @@ Status ValidateDotDimensionNumbers( dimension_numbers.DebugString()); } - // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. - const int64 lhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(lhs) - - dimension_numbers.lhs_contracting_dimensions_size() - - dimension_numbers.lhs_batch_dimensions_size(); - const int64 rhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(rhs) - - dimension_numbers.rhs_contracting_dimensions_size() - - dimension_numbers.rhs_batch_dimensions_size(); - if (lhs_non_contracting_non_batch_dims < 0 || - lhs_non_contracting_non_batch_dims > 1 || - rhs_non_contracting_non_batch_dims < 0 || - rhs_non_contracting_non_batch_dims > 1) { - return InvalidArgument( - "Batch and contracting dimension number mismatch with rank."); - } - - // Check that batch dimension numbers are ordered before all others, and - // that they are monotonically increasing. - std::vector batch_dim_numbers(lhs_batch_dimensions.size()); - std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); - if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - lhs_batch_dimensions.begin()) || - !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - rhs_batch_dimensions.begin())) { - return InvalidArgument( - "Batch dimension numbers must precede non-batch dimensions and be" - "monotonically increasing."); - } - return Status::OK(); } @@ -637,28 +627,33 @@ Status ValidateDotDimensionNumbers( return fail("Element types do not match."); } - if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + if ((lhs.rank() < 1) || (rhs.rank() < 1)) { return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); - // Check that there is only one contracting dimension for both lhs and rhs. + // Check that number of contracting dimensions match. if (dimension_numbers.lhs_contracting_dimensions_size() != - dimension_numbers.rhs_contracting_dimensions_size() || - dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("Must specify one contracting dimension for both lhs and rhs."); + dimension_numbers.rhs_contracting_dimensions_size()) { + return fail( + "Must specify the same number of contracting dimensions for lhs and " + "rhs."); } - // Check that contracting dimension sizes match. - const int64 lhs_contracting_dimension = - dimension_numbers.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = - dimension_numbers.rhs_contracting_dimensions(0); - if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension)) { - return fail("Contracting dimension sizes do not match."); + for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size(); + ++i) { + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(i); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(i); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension) || + lhs.is_dynamic_dimension(lhs_contracting_dimension) != + rhs.is_dynamic_dimension(rhs_contracting_dimension)) { + return fail("Contracting dimension sizes do not match."); + } } // Check that number of batch dimensions match. @@ -669,11 +664,12 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { - if (dimension_numbers.lhs_batch_dimensions(i) != - dimension_numbers.rhs_batch_dimensions(i) || - lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("Batch dimension numbers and sizes must match for lhs/rhs."); + if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) || + lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.is_dynamic_dimension( + dimension_numbers.rhs_batch_dimensions(i))) { + return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -683,21 +679,29 @@ Status ValidateDotDimensionNumbers( // Generate the result dimensions in order, rhs dimensions followed by lhs // dimensions except the contracted and batch dimensions. std::vector dimensions; - std::unordered_set rhs_batch_dims( - dimension_numbers.rhs_batch_dimensions().begin(), - dimension_numbers.rhs_batch_dimensions().end()); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracting_dimension) { + std::vector is_dynamic; + for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { + dimensions.push_back(lhs.dimensions(lhs_dim)); + is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); + } + for (int64 i = 0; i < lhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } - for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { + for (int64 i = 0; i < rhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) { dimensions.push_back(rhs.dimensions(i)); + is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } Shape result = ShapeUtil::MakeShape( - ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -708,20 +712,24 @@ Status ValidateDotDimensionNumbers( ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& lhs, const Shape& rhs) { - TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); + TF_RET_CHECK(lhs.rank() == rhs.rank()); // The shapes have to be compatible. That is, if some dimension d has a // different size in the two shapes, one of them has to be 1 (a "degenerate" // dimension). In that case, the output shape has the non-1 dimension size // from the lhs/rhs pair in every index. - std::vector output_dimensions(ShapeUtil::Rank(lhs)); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) { + std::vector output_dimensions(lhs.rank()); + std::vector output_dimensions_is_dynamic(lhs.rank()); + for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else if (lhs.dimensions(i) == 1) { output_dimensions[i] = rhs.dimensions(i); + output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i); } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -730,7 +738,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions); + output_dimensions, output_dimensions_is_dynamic); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -743,13 +751,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); - } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { + } else if (broadcast_dimensions.size() != smaller_shape.rank()) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %d, size of broadcast_dimensions is " "%u.", - ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); + smaller_shape.rank(), broadcast_dimensions.size()); } // broadcast_dimensions is a sequence of dimensions; its length is equal to @@ -809,6 +817,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 small_dimension_size = smaller_shape.dimensions(i); int64 large_dimension_size = larger_shape.dimensions(dimension_to_match); + bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); + bool large_is_dynamic = + larger_shape.is_dynamic_dimension(dimension_to_match); // Dimension sizes must be compatible: match or be degenerate (degenerate // case is handled by degenerate dimension broadcasting which occurs after // InDim broadcasting). @@ -820,6 +831,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); } + if (small_is_dynamic != large_is_dynamic) { + return InvalidArgument( + "Broadcast dimension %d dynamism mismatch: %s and %s.", i, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { @@ -829,6 +846,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } output_shape.set_dimensions(dimension_to_match, small_dimension_size); + output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic); } return output_shape; @@ -847,8 +865,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { - std::vector identity_dims(ShapeUtil::Rank(lhs)); + if (lhs.rank() == rhs.rank()) { + std::vector identity_dims(lhs.rank()); std::iota(identity_dims.begin(), identity_dims.end(), 0); if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { @@ -865,15 +883,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + if (lhs.rank() == rhs.rank()) { return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. - const Shape& larger_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs; - const Shape& smaller_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; + const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs; + const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, @@ -942,6 +958,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); + } else if (lhs.element_type() == F64 && rhs.element_type() == F64) { + return ShapeUtil::ChangeElementType(shape, C128); } else { return Unimplemented("Complex component type is not implemented."); } @@ -1162,12 +1180,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1177,25 +1195,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-training to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1272,12 +1290,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1287,25 +1305,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-inference to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1417,41 +1435,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } - if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { + if (operand_shape.rank() != output_grad_shape.rank()) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %d, and" " rank(output_grad_shape) %d.", - ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); + operand_shape.rank(), output_grad_shape.rank()); } - if (ShapeUtil::Rank(mean_shape) != 1) { + if (mean_shape.rank() != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(mean_shape)); + mean_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } - if (ShapeUtil::Rank(var_shape) != 1) { + if (var_shape.rank() != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(var_shape)); + var_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1538,7 +1556,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Verify operand_shape and output_grad_shape have same bounds. - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (ShapeUtil::GetDimension(operand_shape, i) != ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( @@ -1573,6 +1591,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, batch_group_count); } + if (batch_group_count > 1 && feature_group_count > 1) { + return InvalidArgument( + "both batch_group_count %d and feature_group_count %d cannot be " + "greater than 1", + batch_group_count, feature_group_count); + } + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1603,12 +1628,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int num_dims = num_spatial_dims + 2; - if (ShapeUtil::Rank(lhs) != num_dims) { + if (lhs.rank() != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs)); } - if (ShapeUtil::Rank(rhs) != num_dims) { + if (rhs.rank() != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; rhs: %s.", num_dims, ShapeUtil::HumanString(rhs)); @@ -1623,29 +1648,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.input_spatial_dimensions().begin(), dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); - std::sort(input_dnums.begin(), input_dnums.end()); + absl::c_sort(input_dnums); std::vector window_dnums(num_dims); window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension(); std::copy(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); - std::sort(window_dnums.begin(), window_dnums.end()); + absl::c_sort(window_dnums); std::vector output_dnums(num_dims); output_dnums[0] = dnums.output_batch_dimension(); output_dnums[1] = dnums.output_feature_dimension(); std::copy(dnums.output_spatial_dimensions().begin(), dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); - std::sort(output_dnums.begin(), output_dnums.end()); + absl::c_sort(output_dnums); std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; - if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || - !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { + if (!absl::c_all_of(input_dnums, in_range) || + !absl::c_all_of(window_dnums, in_range) || + !absl::c_all_of(output_dnums, in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", dnums.DebugString()); @@ -1686,6 +1711,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); + if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "input_batch (value %d) for batch group count %d; " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, input_batch, batch_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } + if (input_features % feature_group_count != 0 || input_features / feature_group_count != kernel_input_features) { return InvalidArgument( @@ -1747,8 +1783,33 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } + std::vector is_dynamic(num_dims); + for (int i = 0; i < num_dims; i++) { + if (lhs.is_dynamic_dimension(i)) { + if (i == dnums.input_batch_dimension()) { + is_dynamic[dnums.output_batch_dimension()] = true; + } else if (i == dnums.input_feature_dimension()) { + // Input feature dimension is a contracting dimension, which does not + // affect the output dimension size. So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: lhs shape is %s ", + lhs.ToString()); + } + } + if (rhs.is_dynamic_dimension(i)) { + if (i == dnums.kernel_input_feature_dimension()) { + // Kernel feature dimension does not affect the output dimension size. + // So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: rhs shape is %s ", + rhs.ToString()); + } + } + } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - dimensions); + dimensions, is_dynamic); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1769,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s.", + return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); } @@ -1853,12 +1914,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& shape, int64 split_dimension, int64 concat_dimension, int64 split_count) { TF_RET_CHECK(split_count > 0); - if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + if (split_dimension >= shape.rank() || split_dimension < 0) { return InvalidArgument( "AllToAll split_dimension %d is out-of-bounds in shape %s.", split_dimension, ShapeUtil::HumanString(shape)); } - if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + if (concat_dimension >= shape.rank() || concat_dimension < 0) { return InvalidArgument( "AllToAll concat_dimension %d is out-of-bounds in shape %s.", concat_dimension, ShapeUtil::HumanString(shape)); @@ -1896,7 +1957,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferCollectivePermuteShape( const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsArray(shape)); + TF_RET_CHECK(shape.IsArray()); return shape; } @@ -1920,7 +1981,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( - "All reduced tensors must have the sime dimension. Tensor 0 has " + "All reduced tensors must have the same dimension. Tensor 0 has " "shape %s, Tensor %d has shape %s", ShapeUtil::HumanString(*reduced_args[0]), i, ShapeUtil::HumanString(*reduced_args[i])); @@ -1932,7 +1993,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // doesn't matter which one we choose. const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { - if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { + if (dimension >= arg.rank() || dimension < 0) { return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", dimension, ShapeUtil::HumanString(arg)); } @@ -1949,20 +2010,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); std::vector new_dimensions; - for (int i = 0; i < ShapeUtil::Rank(arg); ++i) { + std::vector new_is_dynamic; + for (int i = 0; i < arg.rank(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); + new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions); + new_dimensions, new_is_dynamic); } else { std::vector result_subshapes; for (const Shape& subshape : to_apply.result().tuple_shapes()) { - result_subshapes.push_back( - ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + result_subshapes.push_back(ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic)); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -2036,12 +2099,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(source_shape), ShapeUtil::HumanString(window_result_shape)); } + return operand_shape; } /* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64 dimension) { - if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", dimension); } @@ -2083,10 +2147,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, starts.size(), strides.size())); } - if (starts.size() != ShapeUtil::Rank(arg)) { + if (starts.size() != arg.rank()) { return InvalidArgument( "Slice index count does not match argument rank: %u vs %d.", - starts.size(), ShapeUtil::Rank(arg)); + starts.size(), arg.rank()); } std::vector sizes; @@ -2121,41 +2185,87 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes) { + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR( - ExpectArray(start_indices_shape, "start indices of dynamic slice")); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic slice should have exactly 1 index operand, has %d.", + number_of_indices); + } - VLOG(2) << StrFormat( - "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); + const Shape& start_indices_shape = start_index_shapes[0]; + VLOG(2) << StrFormat( + "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + StrJoin(slice_sizes, ", ")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + TF_RETURN_IF_ERROR( + ExpectArray(start_indices_shape, "start indices of dynamic slice")); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return InvalidArgument( - "Dynamic slice start number of dimensions %d (%s) must match rank " - "%d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + StrJoin(slice_sizes, ", ")); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument("Dynamic slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must all have the same shape, got " + "mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } - if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { + if (slice_sizes.size() != operand_shape.rank()) { return InvalidArgument( "Dynamic slice index count does not match argument rank: %u vs %d.", - slice_sizes.size(), ShapeUtil::Rank(operand_shape)); + slice_sizes.size(), operand_shape.rank()); } for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { @@ -2178,46 +2288,92 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape) { + absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( ExpectArray(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, - "start indices of dynamic update slice")); - VLOG(2) << StrFormat( - "updating slice of shape %s at dynamic start_indices %s with update " - "shape %s", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::HumanString(update_shape)); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic update slice should have exactly 1 index operand, has %d.", + number_of_indices); + } + const Shape& start_indices_shape = start_index_shapes[0]; + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic update slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + VLOG(2) << StrFormat( + "updating slice of shape %s at dynamic start_indices %s with update " + "shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic update slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic update slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return InvalidArgument( - "Dynamic update slice start number of dimensions %d (%s) must match " - "rank %d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("updating slice of shape %s with update shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(update_shape)); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must all have the same " + "shape, got mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } - if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { + if (update_shape.rank() != operand_shape.rank()) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " "%d vs %d.", - ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); + update_shape.rank(), operand_shape.rank()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, @@ -2229,7 +2385,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, PrimitiveType_Name(update_shape.element_type())); } - for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { + for (int64 dim = 0; dim < operand_shape.rank(); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { @@ -2255,7 +2411,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("a dimension number is duplicated in reverse"); } for (int64 dimension : dimensions) { - if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { + if (dimension >= operand_shape.rank() || dimension < 0) { return InvalidArgument( "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape)); @@ -2266,7 +2422,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferGetTupleElementShape( const Shape& arg, int64 index) { - if (!ShapeUtil::IsTuple(arg)) { + if (!arg.IsTuple()) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg)); @@ -2302,7 +2458,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }; // Check the shapes of computation parameters and return types. - if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { + if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Condition must return a boolean; got %s.", shape_string()); } @@ -2322,7 +2478,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& predicate, const Shape& true_operand, const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { - if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate)); } @@ -2397,8 +2553,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); - const int64 operand_rank = ShapeUtil::Rank(operand_shape); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 operand_rank = operand_shape.rank(); + const int64 output_rank = output_shape.rank(); if (operand_rank > output_rank) { return InvalidArgument( "InDim style broadcast must be to an equal or higher ranked shape; " @@ -2426,6 +2582,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); } + if (operand_shape.is_dynamic_dimension(i) != + output_shape.is_dynamic_dimension(broadcast_dimensions[i])) { + return InvalidArgument( + "Broadcast input and output dynamism mismatch: %s and %s", + operand_shape.ToString(), output_shape.ToString()); + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) { @@ -2457,9 +2619,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(inferred_shape)); } - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2468,6 +2630,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape); + for (auto& unmodified : unmodified_dims) { + if (operand.is_dynamic_dimension(unmodified.first)) { + inferred_shape.set_dynamic_dimension(unmodified.second, true); + } + } + return inferred_shape; } @@ -2475,9 +2645,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2548,12 +2718,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // dimensions as on_true and on_false. return ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - } else { - return InvalidArgument( - "Select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s.", - ShapeUtil::HumanString(pred)); } + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + "different from the other operands: %s.", + ShapeUtil::HumanString(pred)); } /* static */ StatusOr ShapeInference::InferTupleSelectShape( @@ -2829,7 +2998,7 @@ Status ValidateScatterDimensionNumbers( "update_window_dims in scatter op must not repeat; got: %s.", StrJoin(dim_numbers.update_window_dims(), ", ")); } - const int64 updates_rank = ShapeUtil::Rank(updates_shape); + const int64 updates_rank = updates_shape.rank(); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( @@ -2863,10 +3032,10 @@ Status ValidateScatterDimensionNumbers( // Validate window size. auto window_size = dim_numbers.update_window_dims_size() + dim_numbers.inserted_window_dims_size(); - if (window_size != ShapeUtil::Rank(operand_shape)) { + if (window_size != operand_shape.rank()) { return InvalidArgument( "Scatter op has window of size %d; doesn't match operand of rank %d.", - window_size, ShapeUtil::Rank(operand_shape)); + window_size, operand_shape.rank()); } // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. @@ -2951,10 +3120,9 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); - if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + if (updates_shape.rank() != expected_updates_rank) { return InvalidArgument("Updates tensor must be of rank %d; got %d.", - expected_updates_rank, - ShapeUtil::Rank(updates_shape)); + expected_updates_rank, updates_shape.rank()); } TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( @@ -2985,7 +3153,7 @@ Status ValidateScatterDimensionNumbers( } int64 scatter_dims_seen = 0; - for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + for (int64 i = 0; i < updates_shape.rank(); ++i) { bool is_update_window_dim = absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 1b8fd10d691498087b28ef68517868c5def1da5a..7d39ef38e05abf0a81683c1fb0f3999908b27d23 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -176,14 +176,15 @@ class ShapeInference { // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes); + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. static StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape); + absl::Span start_index_shapes, + bool allow_scalar_indices = true); // Infers the shape produced by doing a compile-time-constant indexing into // the given input shape. This is essential for operations on tuples, because diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0a870808d4cd89fa18382522ea5a4bf2355e5ce7..26120a06b823c9fddf378991cec434a880fb888d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. const Shape s32_ = ShapeUtil::MakeShape(S32, {}); + const Shape f16_ = ShapeUtil::MakeShape(F16, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); @@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); // Component types must match. ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); - // Only F32->C64 supported. - ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Only F32->C64 and F64->C128 supported. + ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok()); // Validate correct uses. Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); @@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {}))); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { @@ -1006,9 +1010,9 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch and contracting dimension number mismatch")); + EXPECT_TRUE(inferred_status.ok()); + EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {32, 32, 64}))); } // vector vector -> scalar @@ -1100,7 +1104,6 @@ TEST_F(ShapeInferenceTest, DotGeneral) { TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); - Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1114,8 +1117,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Must specify one contracting dimension for both " - "lhs and rhs")); + HasSubstr("Must specify the same number of contracting " + "dimensions for lhs and rhs.")); +} + +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + EXPECT_TRUE(inferred_status.ok()); + EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } // BatchMatMul with different batch dimension sizes fails. @@ -1134,11 +1157,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension sizes must match")); } -// BatchMatMul with different batch dimension numbers fails. -TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { +// BatchMatMul with different batch dimension numbers passes +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); @@ -1151,9 +1174,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers must precede non-batch")); + ASSERT_TRUE(inferred_status.ok()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails. diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 28a30b5ee2dbcb5012804578d4d037c241045309..d90dde3b13d3aa9e1de10dd9e1d11a8e6da170de 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -85,7 +85,7 @@ string ShapedBuffer::ToString() const { on_device_shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { shape_str = "tuple"; } else { shape_str = ShapeUtil::HumanStringWithLayout(subshape); diff --git a/tensorflow/compiler/xla/service/sort_simplifier.cc b/tensorflow/compiler/xla/service/sort_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a00e8d7b227f14d462ca53f695189f3f48754ee --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" + +namespace xla { +namespace { + +// If the sort instruction has a tuple shape then looks for unused output +// values and removes them from the sort instruction. Returns true if the +// graph has been modified. +StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { + if (!sort->shape().IsTuple()) { + return false; + } + + HloComputation* computation = sort->parent(); + + if (computation->root_instruction() == sort) { + // Can't analyse users of the root instruction. + return false; + } + + // Index 0 is the sorting key used by the sort HLO itself. + absl::flat_hash_set used_indices{0}; + for (const HloInstruction* user : sort->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + // Can't analyse users other then get-tuple-element. + return false; + } + used_indices.insert(user->tuple_index()); + } + + if (used_indices.size() == sort->operand_count()) { + // All operands are used. + return false; + } + + std::vector operands{sort->mutable_operand(0)}; + std::vector new_shapes{sort->operand(0)->shape()}; + for (int64 i = 1; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + operands.push_back(sort->mutable_operand(i)); + new_shapes.push_back(sort->operand(i)->shape()); + } + } + + Shape new_sort_shape = new_shapes.size() == 1 + ? new_shapes[0] + : ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, operands)); + + // Map from original get-tuple-element tuple index to new HLO instruction + absl::flat_hash_map result_map; + if (new_sort->shape().IsTuple()) { + // Old sort key maps to new sort key. + int64 new_index = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + result_map[i] = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_shapes[new_index], new_sort, new_index)); + ++new_index; + } + } + } else { + result_map[0] = new_sort; + } + std::vector users(sort->users().begin(), + sort->users().end()); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR( + user->ReplaceAllUsesWith(result_map.at(user->tuple_index()))); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user)); + } + return true; +} +} // namespace + +StatusOr SortSimplifier::Run(HloModule* module) { + VLOG(2) << "HLO module before SortSimplifier:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector sort_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(sort_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kSort; + }); + } + + for (HloInstruction* sort_instr : sort_instrs) { + TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after SortSimplifier:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after SortSimplifier"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sort_simplifier.h b/tensorflow/compiler/xla/service/sort_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..8c6f313aa04f51e14a14450bc72fc622d74133a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes unused operands from sort, where an unused operand is +// defined as an operand at some index 'x' at which the output is not used. +class SortSimplifier : public HloModulePass { + public: + absl::string_view name() const override { return "simplify-sorts"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd05fcf830d32e8bac4f8b260d3dd143ab98ad7b --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using SortSimplifierTest = HloTestBase; + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1} + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while (simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0)))); +} + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1} + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 0), + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 1)))); +} + +TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index a21e586efadb85d18e88e44999283b28f7f65eac..15ef623cc7b2dbc31e9cba5c4783c39b8805a5aa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -142,7 +142,7 @@ Status TransferManager::TransferArrayToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); - TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + TF_RET_CHECK(on_device_shape.IsArray()) << "On-device representation of " << ShapeUtil::HumanString(literal.shape()) << " is not an array: " << ShapeUtil::HumanString(on_device_shape); @@ -227,7 +227,7 @@ Status TransferManager::WriteTupleIndexTablesAsync( return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { - if (ShapeUtil::IsTuple(device_subshape)) { + if (device_subshape.IsTuple()) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); @@ -248,6 +248,22 @@ Status TransferManager::WriteTupleIndexTablesAsync( }); } +Status TransferManager::WriteRootTupleIndexTable( + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RET_CHECK(device_buffer.on_device_shape().IsTuple()); + se::DeviceMemoryBase device_memory = device_buffer.buffer({}); + TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) == + device_memory.size()); + + std::vector elements; + for (int64 i = 0; + i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) { + elements.push_back(device_buffer.buffer({i})); + } + return WriteSingleTupleIndexTable( + stream, elements, device_buffer.on_device_shape(), &device_memory); +} + Status TransferManager::TransferBufferFromDevice( se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 49f0b8f8b72001f07200d3e94828f60fcb0fa8fb..43a50487c636da75224547286a31625db3f91330 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -146,6 +146,12 @@ class TransferManager { Status WriteTupleIndexTablesAsync(se::Stream* stream, const ShapedBuffer& device_buffer); + // Writes a tuple index buffer for the root of 'device_buffer', which must + // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, + // rather than writing all subbuffers. This method is always asynchronous. + Status WriteRootTupleIndexTable(se::Stream* stream, + const ShapedBuffer& device_buffer); + // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index eaf4f28b87ce7706832eebb0bc02d015e64ee89a..a95ca2bf2a8fcd700eb9234cafbfce9b62f2370c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); - } else if (ShapeUtil::Rank(operand.shape()) != 2) { + } else if (operand.shape().rank() != 2) { return {}; } } @@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_lhs; const int64 kLhsIdx = 0; - if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kLhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_rhs; const int64 kRhsIdx = 1; - if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kRhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 50d51eaeb762e208004c1dae3dcc27503f3f94e9..5e505aaf02f157d0cba9dff42b1a9b89a6691504 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const { bool PointsToSet::IsDistinct() const { bool distinct = true; - std::set all_points_to; - ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/, - const BufferList& points_to) { + absl::flat_hash_set all_points_to; + ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) { for (auto& buffer : points_to) { - if (all_points_to.count(buffer) != 0) { + if (all_points_to.contains(buffer)) { distinct = false; } all_points_to.insert(buffer); @@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool found = false; ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, const BufferList& pointed_to_buffers) { - if (!found && - std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end()) { + if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) { found = true; } }); @@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, const ShapeIndex& index) const { const auto& pointed_to_buffers = element(index); - return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end(); + return absl::c_linear_search(pointed_to_buffers, &buffer); } void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, @@ -210,7 +207,7 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { &logical_buffer_analysis_->GetBuffer(hlo_instruction, index)); }); - if (ShapeUtil::IsTuple(hlo_instruction->shape())) { + if (hlo_instruction->shape().IsTuple()) { // If the hlo instruction is a tuple-shaped, then trivially the instruction // itself is the source of the tuple. points_to_set.add_tuple_source({}, hlo_instruction); @@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( } else if (user->opcode() == HloOpcode::kFusion && user->fusion_kind() == HloInstruction::FusionKind::kLoop) { // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { + auto it = absl::c_find_if( + user->fused_parameters(), [&](HloInstruction* fused_param) { return user->operand(fused_param->parameter_number()) == operand; }); CHECK(it != user->fused_parameters().end()); @@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( } // Find fusion parameter associated with 'operand'. const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { + auto fused_param_it = + absl::c_find_if(fused_params, [&](HloInstruction* fused_param) { return fusion->operand(fused_param->parameter_number()) == operand; }); if (fused_param_it == fused_params.end()) { @@ -743,11 +738,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 561762b5d424ed5f537665be9d67a81dc8bdd56e..fd5759e44230db8223822d6ae0f511027f73d8f9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -623,7 +623,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { void Run(const bool add_additional_gte0_user) { Shape input_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {3}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape}); @@ -657,7 +657,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2)); // Update 'input' with 'update' at dynamic 'starts' indices. builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - input_shape, input, update, starts)); + input_shape, input, update, {starts})); // Build computation and add it to module as entry computation. BuildModule(builder.Build()); @@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // to fusion 'operand'. HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, HloInstruction* operand) { - auto it = std::find_if( - fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), [=](const HloInstruction* fused) { + auto it = absl::c_find_if( + fusion->fused_instructions(), [&](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); @@ -734,7 +733,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // Returns all users of 'fusion_paran' at 'tuple_index'. std::vector GetFusionParameterUsersAt( HloInstruction* fusion_param, int64 tuple_index) { - CHECK(ShapeUtil::IsTuple(fusion_param->shape())); + CHECK(fusion_param->shape().IsTuple()); std::vector users_at_tuple_index; for (auto user : fusion_param->users()) { CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode()); @@ -883,12 +882,12 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -977,12 +976,12 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -1004,7 +1003,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( @@ -1012,7 +1011,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto starts = builder.AddInstruction( HloInstruction::CreateParameter(2, starts_shape, "starts")); auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {starts})); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index cfb0c787d09557fd1aec3517eb9698cfec323369..90ea79ec263a038556ccbd2cd345b337c5a5dcf3 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -21,7 +21,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, int64 elements) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); @@ -41,7 +41,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, absl::Span trailing_values) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 68e2569f66bea9ec1223e454d1ead0efc7b9498e..c93a9ba3176002a34fe84a29e62075de4d19168f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -301,7 +301,7 @@ optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { /*dest_shape_index=*/{indvar_index}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition."; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 41011176ffa91e885bc58364d1fb19617d3518ad..69cc8feb3f31ad782b9d3437d81d0ab8ce10aadb 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy( HloInstruction* next_operand = frame->instruction->mutable_operand(frame->operand_index++); - if (hoisted_instructions->count(next_operand) || + if (hoisted_instructions->contains(next_operand) || next_operand == while_body_param) { continue; } @@ -127,7 +127,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); - if (!ShapeUtil::IsTuple(while_instr->shape())) { + if (!while_instr->shape().IsTuple()) { // This restriction leaves one interesting pattern on the table: // // while_body(f32[1024, 1024] %param) { @@ -168,7 +168,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // is no benefit to hoisting them unless something that uses it is also // hoisted. for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - if (ShapeUtil::IsArray(instr->shape())) { + if (instr->shape().IsArray()) { // TODO(b/79147885): We should try to generalize this to tuples for // uniformity's sake, if nothing else. InsertOrDie(&unhoisted_invariant_instructions, instr); @@ -221,7 +221,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( operand->shape(), [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { input_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -229,7 +229,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( instruction->shape(), [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || - unhoisted_invariant_instructions.count(op) || + unhoisted_invariant_instructions.contains(op) || op->opcode() == HloOpcode::kConstant; }; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8e7c4bc8828552e197b41f874c070d496b85a382..3587c016b4420163a607422b1acc838646fab83a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -299,7 +299,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); @@ -314,10 +314,12 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); HloInstruction* in_token = builder.AddInstruction( HloInstruction::CreateGetTupleElement(token_shape, param, 2)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); - HloInstruction* out_token = builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = + builder.AddInstruction(HloInstruction::CreateOutfeed( + effective_scalar_s32, bitcast_inst, in_token, "")); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); @@ -352,9 +354,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = - ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); + Shape while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, effective_scalar_s32, effective_scalar_s32}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -363,12 +365,13 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { HloInstruction* gte_0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1)); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); HloInstruction* add_inst = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1)); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index d30f67dd8110b88166fe807762fb653190ec00bc..386ffb995477ff1b4aef73080b6a6fd988dd1980 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -58,7 +58,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if (!ShapeUtil::IsTuple(while_init->shape())) { + if (!while_init->shape().IsTuple()) { VLOG(2) << "While op's carried value isn't tuple shaped."; return false; } @@ -109,8 +109,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // operand appears in, but it may appear more than once! if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && - std::count(while_body_root->operands().begin(), - while_body_root->operands().end(), user) == 1) { + absl::c_count(while_body_root->operands(), user) == 1) { continue; } @@ -127,7 +126,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // through to the while body's root, count that element as "used", since // removing that element would be observable. for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.count(i)) { + if (used_tuple_indices.contains(i)) { continue; } @@ -158,7 +157,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), used_tuple_indices.end()); - std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + absl::c_sort(new_to_old_tuple_idx); absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { @@ -181,7 +180,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // replace the old instructions after we remove unused elements from the while // tuple. auto make_while_computation_replacements = [&](const HloComputation* comp) { - std::unordered_map> + absl::flat_hash_map> replacements; auto* param = comp->parameter_instruction(0); @@ -233,7 +232,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_cond->CloneWithReplacements( make_while_computation_replacements(while_cond)); - std::unordered_map> + absl::flat_hash_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); @@ -583,8 +582,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { static std::unique_ptr UnflattenTupleInstr( absl::Span instrs, const Shape& desired_shape, std::vector>* new_instrs) { - CHECK(ShapeUtil::IsTuple(desired_shape)) - << ShapeUtil::HumanString(desired_shape); + CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape); // For each child shape in `desired_shape`, slice out the correct number of // `instrs` and call UnflattenTupleInstr recursively. At each step we remove @@ -593,7 +591,7 @@ static std::unique_ptr UnflattenTupleInstr( std::vector elems; for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { const Shape& subshape = desired_shape.tuple_shapes(i); - if (!ShapeUtil::IsTuple(subshape)) { + if (!subshape.IsTuple()) { elems.push_back(instrs[0]); instrs.remove_prefix(1); continue; @@ -603,7 +601,7 @@ static std::unique_ptr UnflattenTupleInstr( int64 num_leaves = 0; ShapeUtil::ForEachSubshape( subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { ++num_leaves; } }); @@ -625,7 +623,7 @@ static std::vector GetFlatTupleElems( HloInstruction* instr, std::vector>* new_instrs) { const auto& shape = instr->shape(); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return {instr}; } std::vector elems; @@ -665,7 +663,7 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { std::vector flattened_shape_elems; ShapeUtil::ForEachSubshape(while_shape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { flattened_shape_elems.push_back(s); } }); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 3713989ca2f64ee1d94c9f77255017909d957da2..ecca76b1e86d833c73fbb9bad6a341660a7d2669 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -407,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. + const auto& instrs = m->entry_computation()->instructions(); HloInstruction* new_while_op = - *std::find_if(m->entry_computation()->instructions().begin(), - m->entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return (instr->opcode() == HloOpcode::kWhile && - instr->name() != "while"); - }); + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 039ccda7322f5efda6a827efbeda1225c3596cc0..d77386497a14b3e52be2ea7f655fa330f60e4a97 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -97,7 +97,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, absl::Span instructions) { - CHECK(ShapeUtil::IsTuple(while_instr->shape())); + CHECK(while_instr->shape().IsTuple()); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); Shape new_while_shape = while_instr->shape(); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 83d696fe0915086c3c98b6d7cbdaeaeb4d9d0bdb..661b7aa7d99ca549da6a509812760a1665d60919 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -31,16 +31,21 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - if (instruction->HasSideEffect() || - !ShapeUtil::IsArray(instruction->shape()) || + if (instruction->HasSideEffect() || !instruction->shape().IsArray() || instruction->opcode() == HloOpcode::kConstant) { continue; } if (comp->IsRemovable(instruction) && ShapeUtil::IsZeroElementArray(instruction->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal. + Shape shape = instruction->shape(); + if (!LayoutUtil::HasLayout(shape)) { + LayoutUtil::SetToDefaultLayout(&shape); + } TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( - instruction, HloInstruction::CreateConstant( - Literal::CreateFromShape(instruction->shape())))); + instruction, + HloInstruction::CreateConstant(Literal::CreateFromShape(shape)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a546a6d39cc55d1f327b8449c7d26cd4c95dbf98..572a79609e7a912277af0fd2ba43f9a1e14a6f52 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -82,5 +82,18 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { EXPECT_FALSE(changed); } +TEST_F(ZeroSizedHloEliminationTest, ZeroSizedInstructionWithoutLayoutFolded) { + Shape op_shape = ShapeUtil::MakeShape(F32, {4, 0}); + op_shape.clear_layout(); + HloInstruction* param1 = builder_.AddInstruction( + HloInstruction::CreateParameter(1, op_shape, "zero sized param 1")); + HloInstruction* param2 = builder_.AddInstruction( + HloInstruction::CreateParameter(2, op_shape, "zero sized param 2")); + builder_.AddInstruction( + HloInstruction::CreateBinary(op_shape, HloOpcode::kAdd, param1, param2)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index b206345db2ac2940b1f139c82fa03a93538b5ccd..a36d3547a0987422c2658b0f3046f7b1f83369c6 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -27,6 +27,21 @@ Shape::Shape(const ShapeProto& shape_proto) { for (const int64 dimension : shape_proto.dimensions()) { add_dimensions(dimension); } + // A malformed proto may have different is_dynamic_dimension_size and + // dimensions_size. Since C++ is evil, and we have no good way of bailing out + // in a constructor, conservatively trim the is_dynamic_dimension size. + // TODO(b/120111794): Make this a hard error when we have a factory method + // instead of a constructor. + if (shape_proto.dimensions_size() != + shape_proto.is_dynamic_dimension_size()) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } + int64 num_dynamic_dimension_fields = std::min( + shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); + for (int i = 0; i < num_dynamic_dimension_fields; i++) { + dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i); + } tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { *add_tuple_shapes() = Shape(element_shape); @@ -43,6 +58,9 @@ ShapeProto Shape::ToProto() const { for (const int64 dimension : dimensions()) { proto.add_dimensions(dimension); } + for (const bool dynamic : dynamic_dimensions_) { + proto.add_is_dynamic_dimension(dynamic); + } proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); for (const Shape& shape : tuple_shapes()) { *proto.add_tuple_shapes() = shape.ToProto(); @@ -61,6 +79,112 @@ string Shape::ToString(bool print_layout) const { } } +bool Shape::is_static() const { + if (IsTuple()) { + for (const Shape& subshape : tuple_shapes_) { + if (!subshape.is_static()) { + return false; + } + } + } + return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); +} + +void Shape::DeleteDimension(int64 dim_to_delete) { + CHECK(IsArray()); + CHECK_GE(dim_to_delete, 0); + CHECK_LT(dim_to_delete, dimensions_.size()); + dimensions_.erase(dimensions_.begin() + dim_to_delete); + dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete); + if (LayoutUtil::HasLayout(*this)) { + layout_.set_format(DENSE); + for (int64 i = 0; i < layout_.minor_to_major().size();) { + if (layout_.minor_to_major(i) == dim_to_delete) { + layout_.mutable_minor_to_major()->erase( + layout_.mutable_minor_to_major()->begin() + i); + continue; + } + if (layout_.minor_to_major(i) > dim_to_delete) { + (*layout_.mutable_minor_to_major())[i] -= 1; + } + ++i; + } + } +} + +bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { + if (lhs.IsTuple()) { + return rhs.IsTuple() && + absl::c_equal( + lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { return (*this)(l, r); }); + } else if (!lhs.IsArray()) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return lhs.element_type() == rhs.element_type(); + } + + if (!rhs.IsArray()) { + return false; + } + + if (!ignore_element_type_) { + if ((ignore_fp_precision_ && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + } + + if (!ignore_layout_) { + if (lhs.layout().format() != rhs.layout().format()) { + VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; + return false; + } + if (LayoutUtil::IsDenseArray(lhs)) { + if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + + const auto& lhs_tiles = lhs.layout().tiles(); + const auto& rhs_tiles = rhs.layout().tiles(); + if (lhs_tiles.size() != rhs_tiles.size()) { + return false; + } + for (int64 i = 0; i < lhs_tiles.size(); i++) { + if (!absl::c_equal(lhs_tiles[i].dimensions(), + rhs_tiles[i].dimensions())) { + return false; + } + } + + if (lhs.layout().element_size_in_bits() != + rhs.layout().element_size_in_bits()) { + return false; + } + } + } + + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + + if (!ignore_dynamic_dimension_) { + for (int i = 0; i < lhs.rank(); ++i) { + if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { + VLOG(3) + << "CompareShapes: lhs and rhs have different dynamic dimensions."; + return false; + } + } + } + return true; +} + std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.ToString(/*print_layout=*/true); return out; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7643f64d8a5f0450be1cddad35cf7422afb89048..e6b4e872f69e16ea407dc18cadfc83618080084f 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -44,6 +45,43 @@ class Shape { // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". string ToString(bool print_layout = false) const; + // Returns the rank (number of dimensions) of the given shape. Shape must be + // an array. + int64 rank() const { + CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); + return dimensions_.size(); + } + + // Returns whether the shape is of the specified type (array, tuple, etc). + bool IsArray() const { return primitive_util::IsArrayType(element_type()); } + bool IsTuple() const { return element_type() == TUPLE; } + bool IsToken() const { return element_type() == TOKEN; } + bool IsOpaque() const { return element_type() == OPAQUE; } + + // Returns true if no array dimension in the shape is dynamically sized. Tuple + // shapes are traversed recursively. + bool is_static() const; + + // Returns true if the given dimension is dynamically-sized. + bool is_dynamic_dimension(int dimension) const { + return dynamic_dimensions_.at(dimension); + } + + // Sets whether or not the given dimension is dynamically-sized. + void set_dynamic_dimension(int dimension, bool is_dynamic) { + dynamic_dimensions_[dimension] = is_dynamic; + } + + const std::vector& dynamic_dimensions() const { + return dynamic_dimensions_; + } + + // Add dimension_upper_bound(). + + // Removes the given dimension form the shape. Layout, if it exists, is + // adjusted to match the modified shape. + void DeleteDimension(int64 dim_to_delete); + // The following methods mirror the protobuf generated code interface for the // message ShapeProto. This enabled easy migration of this data structure // from a proto to a proper C++ class. @@ -58,10 +96,16 @@ class Shape { int dimensions_size() const { return dimensions_.size(); } int64 dimensions(int index) const { return dimensions_.at(index); } void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } - void add_dimensions(int64 value) { dimensions_.push_back(value); } - void clear_dimensions() { dimensions_.clear(); } + void add_dimensions(int64 value) { + dimensions_.push_back(value); + dynamic_dimensions_.push_back(false); + } + void clear_dimensions() { + dimensions_.clear(); + dynamic_dimensions_.clear(); + } const std::vector& dimensions() const { return dimensions_; } - std::vector* mutable_dimensions() { return &dimensions_; } + absl::Span mutable_dimensions() { return absl::MakeSpan(dimensions_); } // Methods for accessing the tuple subshapes. This field only non-empty for // tuple shapes. @@ -98,13 +142,58 @@ class Shape { string ShortDebugString() const { return ToProto().ShortDebugString(); } string DebugString() const { return ToProto().DebugString(); } - public: + // Equal is a configurable functor to check the equality of two shapes. + // + // Examples: + // + // - Comparing two shapes ignoring they layout difference: + // Equal().IgnoreLayout()(shape1, shape2); + // + // - Comparing two shapes ignoring they layout and element type difference: + // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); + class Equal { + public: + Equal() = default; + + bool operator()(const Shape& lhs, const Shape& rhs); + + Equal& IgnoreLayout() { + ignore_layout_ = true; + return *this; + } + Equal& IgnoreElementType() { + ignore_element_type_ = true; + return *this; + } + Equal& IgnoreFpPrecision() { + ignore_fp_precision_ = true; + return *this; + } + Equal& IgnoreDynamicDimension() { + ignore_dynamic_dimension_ = true; + return *this; + } + + public: + bool ignore_layout_ = false; + bool ignore_element_type_ = false; + bool ignore_fp_precision_ = false; + bool ignore_dynamic_dimension_ = false; + }; + + private: // The element type of this shape (tuple, array, etc). PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; - // The array bounds of the dimensions. This is nonempty only for array shapes. + // The array bounds of the dimensions. This is nonempty only for array + // shapes. For a dynamically-sized dimension, the respective value in this + // vector is an inclusive upper limit of the array bound. std::vector dimensions_; + // This vector is the same size as 'dimensions_' and indicates whether the + // respective dimension is dynamically sized. + std::vector dynamic_dimensions_; + // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index d44db89d571891ecef554cd45c050017833982bb..a000886d60d06a4a598910c901accb6dfd0a8f1a 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -52,7 +52,7 @@ bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { const Layout& ShapeLayout::layout() const { CHECK(LayoutIsSet()); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); return shape_.layout(); } @@ -61,15 +61,15 @@ void ShapeLayout::Clear() { LayoutUtil::ClearLayout(&shape_); } bool ShapeLayout::LayoutIsSet() const { return LayoutUtil::HasLayout(shape_); } void ShapeLayout::ResetLayout(const Layout& layout) { - CHECK(!ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsOpaque(shape_)); + CHECK(!shape_.IsTuple()); + CHECK(!shape_.IsOpaque()); *shape_.mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } void ShapeLayout::ResetLayout(const Layout& layout, ShapeIndexView shape_index) { - CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(shape_.IsTuple()); *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index e396897eeebc2e7bdc2dc49300c8906710608b05..55ce5fe884e98e474253be9ef694f1b8137b4b01 100644 --- a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -41,11 +41,13 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); const Shape nested_tuple_ = ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); + const Shape dyanmic_matrix_ = + ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape : - {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, + tuple_, nested_tuple_, dyanmic_matrix_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -74,6 +76,47 @@ TEST_F(ShapeTest, ShapeToString) { nested_tuple_.ToString(/*print_layout=*/true)); } +TEST_F(ShapeTest, DynamicShapeToString) { + Shape array_shape = + ShapeUtil::MakeShape(F32, {23, 44, 55}, {true, false, true}); + EXPECT_EQ("f32[<=23,44,<=55]", array_shape.ToString()); + + array_shape.set_dynamic_dimension(2, false); + EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); +} + +TEST_F(ShapeTest, IsStatic) { + EXPECT_TRUE(opaque_.is_static()); + EXPECT_TRUE(token_.is_static()); + EXPECT_TRUE(matrix_.is_static()); + EXPECT_TRUE(tuple_.is_static()); + EXPECT_TRUE(nested_tuple_.is_static()); + + Shape dynamic_matrix = matrix_; + EXPECT_TRUE(dynamic_matrix.is_static()); + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_static()); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + +TEST_F(ShapeTest, IsDynamicDimension) { + Shape dynamic_matrix = matrix_; + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(0)); + EXPECT_TRUE(dynamic_matrix.is_dynamic_dimension(1)); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + TEST_F(ShapeTest, ProgramShapeToFromProto) { ProgramShape program_shape; *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 7bf97729165bef98fabc29040e02203eee68a53c..089120179e2a77518eb5b18c11a35670b03e9b77 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -395,7 +395,7 @@ class ShapeTreeIterator template int64 ShapeTree::CountSubshapes(const Shape& shape) { int64 current_count = 1; - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int64 count = ShapeUtil::TupleElementCount(shape); for (int i = 0; i < count; ++i) { current_count += CountSubshapes(shape.tuple_shapes(i)); @@ -407,7 +407,7 @@ int64 ShapeTree::CountSubshapes(const Shape& shape) { template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; @@ -443,7 +443,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index be7d71ada009535a5c08aa3d3d062fa451cfeef3..1ada4bc0362f86bc770d4adfcd4d4b0ff7379c77 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -81,73 +81,10 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { /* static */ bool ShapeUtil::IsArrayPrimitiveType( PrimitiveType primitive_type) { - return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && - primitive_type != OPAQUE && primitive_type != TOKEN; + return primitive_util::IsArrayType(primitive_type); } namespace { - -// Recursive helper for comparing the equality of two shapes. Returns true if -// the shapes are the same. If compare_layouts is true, then layouts must also -// match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, - bool ignore_fp_precision) { - if ((ignore_fp_precision && - !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || - (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } - - if (ShapeUtil::IsTuple(lhs)) { - return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); - } else if (!ShapeUtil::IsArray(lhs)) { - // Non-tuple, non-array tupes such as opaque and token types are trivially - // the same. - return true; - } - - if (compare_layouts) { - if (lhs.layout().format() != rhs.layout().format()) { - return false; - } - if (LayoutUtil::IsDenseArray(lhs)) { - if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - - const auto& lhs_tiles = lhs.layout().tiles(); - const auto& rhs_tiles = rhs.layout().tiles(); - if (lhs_tiles.size() != rhs_tiles.size()) { - return false; - } - for (int64 i = 0; i < lhs_tiles.size(); i++) { - if (!absl::c_equal(lhs_tiles[i].dimensions(), - rhs_tiles[i].dimensions())) { - return false; - } - } - - if (lhs.layout().element_size_in_bits() != - rhs.layout().element_size_in_bits()) { - return false; - } - } - } - - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - return true; -} - // Constructs and returns the new shape with the given minor_to_major order in // its Layout. StatusOr MakeShapeWithLayoutInternal( @@ -174,12 +111,11 @@ StatusOr MakeShapeWithLayoutInternal( TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); return shape; } - } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/false); + bool equal = Shape::Equal()(lhs, rhs); + if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -190,8 +126,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/true); + bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -200,12 +135,6 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } -/* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)) - << "Non-arrays do not have a rank, shape: " << shape; - return shape.dimensions_size(); -} - /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -232,6 +161,13 @@ StatusOr MakeShapeWithLayoutInternal( return MakeValidatedShape(element_type, dimensions).ValueOrDie(); } +/* static */ Shape ShapeUtil::MakeShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + return MakeValidatedShape(element_type, dimensions, dynamic_dimensions) + .ValueOrDie(); +} + /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)) << element_type; @@ -240,6 +176,17 @@ StatusOr MakeShapeWithLayoutInternal( return result; } +/* static */ StatusOr ShapeUtil::MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + TF_ASSIGN_OR_RETURN(Shape shape, + MakeValidatedShape(element_type, dimensions)); + for (int i = 0; i < dynamic_dimensions.size(); ++i) { + shape.set_dynamic_dimension(i, dynamic_dimensions[i]); + } + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, absl::Span dimensions, absl::Span minor_to_major) { @@ -319,7 +266,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); - shape->mutable_layout()->add_minor_to_major(Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(shape->rank()); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); } @@ -334,7 +281,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (!IsArray(shape)) { + if (!shape.IsArray()) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -358,6 +305,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case U32: case U64: case C64: + case C128: case TUPLE: case OPAQUE: case TOKEN: @@ -376,27 +324,24 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } -/* static */ bool ShapeUtil::IsArray(const Shape& shape) { - return IsArrayPrimitiveType(shape.element_type()); -} - /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { - return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), - shape.tuple_shapes().end(), IsTuple); + return shape.IsTuple() && + absl::c_any_of(shape.tuple_shapes(), + [](const Shape& s) { return s.IsTuple(); }); } /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { - return IsTuple(shape) && TupleElementCount(shape) == 0; + return shape.IsTuple() && TupleElementCount(shape) == 0; } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { - CHECK(IsTuple(shape)) << HumanString(shape); + CHECK(shape.IsTuple()) << HumanString(shape); return shape.tuple_shapes_size(); } /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, int64 index) { - CHECK(IsTuple(shape)); + CHECK(shape.IsTuple()); CHECK_GT(TupleElementCount(shape), index); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); return shape.tuple_shapes(index); @@ -412,7 +357,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); - CHECK(IsTuple(tuple)); + CHECK(tuple.IsTuple()); CHECK_LE(start, TupleElementCount(tuple)); CHECK_LE(limit, TupleElementCount(tuple)); @@ -429,15 +374,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( complex_shape.element_type())); } -/* static */ bool ShapeUtil::ShapeIs(const Shape& shape, - PrimitiveType element_type, - std::initializer_list dimensions) { - return Equal(shape, MakeShape(element_type, dimensions)); -} - /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), shape.rank()); if (shape.dimensions().size() == 1) { return shape.dimensions()[0]; } @@ -447,8 +386,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { - CHECK(IsArray(shape) || IsTuple(shape)); - if (IsArray(shape)) { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { return ElementsIn(shape); } int64 count = 0; @@ -472,7 +411,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; + return shape.IsArray() && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarWithElementType( @@ -481,7 +420,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -491,13 +430,21 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( text += ")"; return text; } + std::vector dim_elements; + for (int i = 0; i < shape.dimensions_size(); ++i) { + if (shape.is_dynamic_dimension(i)) { + dim_elements.push_back(StrCat("<=", shape.dimensions(i))); + } else { + dim_elements.push_back(StrCat(shape.dimensions(i))); + } + } return StrCat( primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + absl::StrJoin(dim_elements, ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -510,10 +457,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( string result = StrCat( primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { - StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + StrAppend(&result, (i > 0) ? "," : "", + shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i)); } result += "]"; - if (!IsScalar(shape) && IsArray(shape)) { + if (!IsScalar(shape) && shape.IsArray()) { if (LayoutUtil::HasLayout(shape)) { StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } @@ -536,43 +484,23 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { - CHECK(ShapeUtil::IsArray(lhs)); - CHECK(ShapeUtil::IsArray(rhs)); + CHECK(lhs.IsArray()); + CHECK(rhs.IsArray()); return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false, - /*ignore_fp_precision=*/false); + return Shape::Equal().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringElementType); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && - CompatibleIgnoringElementType(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs); } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -583,7 +511,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += Rank(shape); + dimension_number += shape.rank(); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -620,6 +548,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return sizeof(double); case C64: return sizeof(complex64); + case C128: + return sizeof(complex128); case TOKEN: // Tokens require no space. return 0; @@ -637,7 +567,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(shape)); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); - } else if (IsArray(shape)) { + } else if (shape.IsArray()) { int64 byte_size = ByteSizeOfElements(shape); if (LayoutUtil::IsSparseArray(shape)) { byte_size += ByteSizeOfSparseIndices(shape); @@ -663,7 +593,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -679,8 +609,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); CHECK(LayoutUtil::IsSparseArray(shape)); - return LayoutUtil::MaxSparseElements(shape.layout()) * - ShapeUtil::Rank(shape) * sizeof(int64); + return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() * + sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( @@ -723,10 +653,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return Status::OK(); } - if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) { return InvalidArgument("sparse arrays must have rank > 0"); } - for (int64 i = 0; i < Rank(shape); ++i) { + for (int64 i = 0; i < shape.rank(); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -742,7 +672,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - if (!IsArray(shape)) { + if (!shape.IsArray()) { return Status::OK(); } @@ -763,7 +693,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return sparse_elements_size; } int64 sparse_indices_size = - MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + MultiplyWithoutOverflow(max_sparse_elements, shape.rank()); if (sparse_indices_size < 0) { return sparse_indices_size; } @@ -835,7 +765,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { const Shape* subshape = &shape; for (auto i : index) { - if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size() || i < 0) { + if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) { return false; } subshape = &subshape->tuple_shapes(i); @@ -847,7 +777,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)) + CHECK(return_shape->IsTuple()) << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } @@ -858,7 +788,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - if (!IsTuple(*return_shape) || i < 0 || + if (!return_shape->IsTuple() || i < 0 || i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", @@ -873,7 +803,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( ShapeIndexView index) { Shape* return_shape = shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(return_shape->IsTuple()); return_shape = return_shape->mutable_tuple_shapes(i); } return return_shape; @@ -881,11 +811,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { - return !IsTuple(GetSubshape(shape, index)); + return !GetSubshape(shape, index).IsTuple(); } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { - if (!IsTuple(shape)) { + if (!shape.IsTuple()) { return 1; } int64 count = 0; @@ -907,7 +837,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return absl::c_linear_search(shape.dimensions(), 1); } @@ -924,7 +854,7 @@ Status ForEachSubshapeHelper(const Shape& shape, const ShapeUtil::StatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachSubshapeHelper( @@ -941,7 +871,7 @@ Status ForEachMutableSubshapeHelper( Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper( @@ -999,6 +929,10 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + for (int64 i = 0; i < shape.rank(); i++) { + new_shape.set_dynamic_dimension(permutation[i], + shape.is_dynamic_dimension(i)); + } // If `shape` has a layout, by contract we choose a new layout such that the // transpose defined by this permutation is a bitcast. @@ -1049,8 +983,8 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { - CHECK(IsArray(shape_pre)); - CHECK(IsArray(shape_post)); + CHECK(shape_pre.IsArray()); + CHECK(shape_post.IsArray()); auto nil = std::make_tuple(false, std::vector(), std::vector()); @@ -1097,7 +1031,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? unmodified_dims[i] - : std::make_pair(Rank(shape_pre), Rank(shape_post)); + : std::make_pair(shape_pre.rank(), shape_post.rank()); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -1109,8 +1043,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), @@ -1160,8 +1094,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); CHECK(LayoutUtil::HasLayout(input_shape)); CHECK(LayoutUtil::HasLayout(output_shape)); @@ -1289,12 +1223,12 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); - for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { + for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(Rank(input_shape), 0); + std::vector input_unit_index(input_shape.rank(), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1320,11 +1254,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); - int64 input_rank = Rank(input_shape); - int64 output_rank = Rank(output_shape); + int64 input_rank = input_shape.rank(); + int64 output_rank = output_shape.rank(); // First, calculate an alignment of the dimensions. A consecutive sequence of // input dimensions and output dimensions belong to the same alignment part if @@ -1461,30 +1395,14 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { - CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + - dim_to_delete); - if (LayoutUtil::HasLayout(shape)) { - Layout* layout = shape.mutable_layout(); - layout->set_format(DENSE); - for (int64 i = 0; i < layout->minor_to_major().size();) { - if (layout->minor_to_major(i) == dim_to_delete) { - layout->mutable_minor_to_major()->erase( - layout->mutable_minor_to_major()->begin() + i); - continue; - } - if (layout->minor_to_major(i) > dim_to_delete) { - (*layout->mutable_minor_to_major())[i] -= 1; - } - ++i; - } - } + CHECK(shape.IsArray()); + shape.DeleteDimension(dim_to_delete); return shape; } /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { - CHECK(IsArray(shape)); + CHECK(shape.IsArray()); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { @@ -1504,8 +1422,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, size_t hash_value = hash()(shape.element_type()); if (shape.tuple_shapes().empty()) { - for (int64 dim : shape.dimensions()) { - hash_value = Hash64Combine(hash_value, hash()(dim)); + for (int i = 0; i < shape.dimensions_size(); ++i) { + hash_value = + Hash64Combine(hash_value, hash()(shape.dimensions(i))); + hash_value = Hash64Combine(hash_value, + hash()(shape.is_dynamic_dimension(i))); } hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 6b7a9cd34f25f2088bdb8d2c7f0412e5d8519d23..fb6da7460e2475732d6f02758e5519fbdb7c0f8d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -185,7 +185,7 @@ class ShapeUtil { // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: IsArray(shape) + // Precondition: shape.IsArray() static int64 ElementsIn(const Shape& shape); // As ElementsIn(), but recurses through tuples. @@ -207,7 +207,7 @@ class ShapeUtil { // Returns the number of bytes used to store the primitive_type. // - // Precondition: ShapeUtil::IsArray(shape) + // Precondition: shape.IsArray() static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -262,7 +262,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -290,16 +290,12 @@ class ShapeUtil { // being F32. Tuple elements are compared recursively for compatibility. static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns whether the lhs and rhs shapes are identical protobufs. + // Returns whether the lhs and rhs shapes are identical. static bool Equal(const Shape& lhs, const Shape& rhs); // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns the rank (number of dimensions) of the given shape. - // Precondition: !IsTuple(shape) - static int64 Rank(const Shape& shape); - // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -313,10 +309,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return IsArray(shape) && Rank(shape) == 0; + return shape.IsArray() && shape.rank() == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return IsArray(shape) && TrueRank(shape) == 0; + return shape.IsArray() && TrueRank(shape) == 0; } // Returns whether "shape" is a scalar (array) with the given element_type. @@ -371,11 +367,24 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions); + // Constructs a new shape with the given element type and sequence of + // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates + // with a true value that the respective dimension is dynamic. If the + // dimension is dynamic then the respective value in 'dimension' is an upper + // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be + // the same size. + static Shape MakeShape(PrimitiveType element_type, + absl::Span dimensions, + const std::vector& dynamic_dimensions); + // Constructs a new shape with the given element type and sequence of // dimensions. Method checks if the element type is valid and the shape's // size fits in std::numeric_limits::max(). static StatusOr MakeValidatedShape(PrimitiveType element_type, absl::Span dimensions); + static StatusOr MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -443,27 +452,6 @@ class ShapeUtil { // that floating point numbers are signed. static bool ElementIsSigned(const Shape& shape); - // Returns whether the shape is a tuple. - static bool IsTuple(const Shape& shape) { - return shape.element_type() == TUPLE; - } - - // Returns whether the shape is an opaque value (i.e. an 'existential' typed - // value that is passed to CustomCall operations). - static bool IsOpaque(const Shape& shape) { - return shape.element_type() == OPAQUE; - } - - // Returns whether the shape is an token value used for ordering - // side-effecting operations. - static bool IsToken(const Shape& shape) { - return shape.element_type() == TOKEN; - } - - // Returns whether the shape is an array. Note that scalars are considered - // arrays. - static bool IsArray(const Shape& shape); - // Returns whether the given primitive type corresponds to an array shape. static bool IsArrayPrimitiveType(PrimitiveType primitive_type); @@ -493,12 +481,6 @@ class ShapeUtil { // shape. static Shape ComplexComponentShape(const Shape& complex_shape); - // Shorthand for testing whether a shape is of a given element type and - // sequence of dimensions. - ABSL_DEPRECATED("Use Equal() instead.") - static bool ShapeIs(const Shape& shape, PrimitiveType element_type, - std::initializer_list dimensions); - // Returns true if the given shape has a subshape at the given index. static bool IndexIsValid(const Shape& shape, ShapeIndexView index); @@ -693,11 +675,9 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { - ForEachIndexWithStatus(shape, - [&](absl::Span indices) { - return StatusOr(visitor_function(indices)); - }) - .IgnoreError(); + ForEachIndexWithStatus(shape, [&](absl::Span indices) { + return StatusOr(visitor_function(indices)); + }).IgnoreError(); } // A parallel version of ForEachIndex(WithStatus). This can only be used if @@ -746,7 +726,7 @@ class ShapeUtil { if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } - CHECK_EQ(Rank(shape), base.size()); + CHECK_EQ(shape.rank(), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); const int64 rank = LayoutUtil::MinorToMajor(shape).size(); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0a3081f5161f80ac97e864ba08d186df4fbdb55d..126ae58293d12182e9b6e30f779f681829729526 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -176,6 +176,28 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); } +TEST(ShapeUtilTest, EqualDynamicShapes) { + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {true, false}))); + EXPECT_FALSE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {false, false}))); +} + +TEST(ShapeUtilTest, CompatibleDynamicShapes) { + Shape shape_a = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_a.mutable_layout() = Layout({1, 0}); + Shape shape_b = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_b.mutable_layout() = Layout({0, 1}); + Shape shape_c = ShapeUtil::MakeShape(F32, {4, 3}, {false, true}); + *shape_c.mutable_layout() = Layout({0, 1}); + + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a)); + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b)); + EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c)); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -516,10 +538,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } -TEST(ShapeUtilTest, ShapeIs) { - EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); -} - TEST(ShapeUtilTest, ForEachIndex) { struct ShapeDimensionAndNumberInvocations { std::vector dimensions; @@ -692,6 +710,26 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { } while (std::next_permutation(layout.begin(), layout.end())); } +TEST(ShapeUtilTest, PermuteDynamicDimensions) { + Shape shape = + ShapeUtil::MakeShape(F32, {10, 100, 1000}, + /*dynamic_dimensions*/ {false, true, true}); + SCOPED_TRACE(absl::StrCat("shape=", shape.ToString())); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); + + auto permuted = ShapeUtil::PermuteDimensions(permutation, shape); + for (int i = 0; i < shape.rank(); i++) { + EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i)); + EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]), + shape.is_dynamic_dimension(i)); + } + } while (std::next_permutation(permutation.begin(), permutation.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index a40bb7875e7ea53a8959a9a67ec09ec260ba9c37..82091bdee65c709bb6020f40acc15f13d8599c1d 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) { } bool SparseIndexArray::Validate(const Shape& shape) const { - if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + if (rank_ == 0 || rank_ != shape.rank()) { return false; } int64 num_indices = index_count(); diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index a96d483462efd77ae4761541e8c79b2c84fa49f3..0c25355467da3fd346d80db790d78252869975ef 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span values) { auto sort_order_less = [this](int64 lhs, int64 rhs) { return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; }; - std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + absl::c_sort(sort_order, sort_order_less); // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index ee24d4d99cb1f7ce51a72c6258cbadd6adf12f81..8fb674255020ced6bfaf5f004758ed48f8621a65 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -71,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", @@ -276,9 +277,6 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -315,6 +313,26 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_backprop_filter_test", + timeout = "long", + srcs = ["conv_depthwise_backprop_filter_test.cc"], + shard_count = 6, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "grouped_convolution_test", timeout = "long", @@ -344,9 +362,6 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -367,9 +382,6 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -387,9 +399,6 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -413,6 +422,10 @@ xla_test( xla_test( name = "xla_hlo_profile_test", srcs = ["xla_hlo_profile_test.cc"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", @@ -436,9 +449,6 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -453,7 +463,6 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -506,9 +515,6 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", @@ -524,9 +530,6 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -544,7 +547,6 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -562,7 +564,6 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -623,9 +624,6 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -648,7 +646,6 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -698,7 +695,6 @@ xla_test( xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -725,7 +721,6 @@ xla_test( srcs = ["dot_operation_test.cc"], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -735,7 +730,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -792,7 +789,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -806,9 +805,6 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -828,9 +824,6 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -841,6 +834,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -951,6 +945,11 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + blacklisted_backends = [ + # BatchNorm HLOs are not handled by the interpreter backend, and the + # BatchNorm expander is not run on the interpreter. + "interpreter", + ], shard_count = 40, deps = [ ":test_utils", @@ -1042,9 +1041,6 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1065,9 +1061,6 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1085,9 +1078,6 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1113,9 +1103,6 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1139,9 +1126,6 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1162,7 +1146,6 @@ xla_test( srcs = ["reduce_test.cc"], shard_count = 40, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1229,7 +1212,6 @@ xla_test( srcs = [], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], xla_test_library_deps = [":reduce_window_test_library"], @@ -1241,7 +1223,6 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1267,9 +1248,6 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", @@ -1290,9 +1268,6 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1306,9 +1281,6 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", @@ -1323,9 +1295,6 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1368,9 +1337,6 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1388,9 +1354,6 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1410,9 +1373,6 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1434,9 +1394,6 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1451,9 +1408,6 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1468,9 +1422,6 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1517,9 +1468,6 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1545,9 +1493,6 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1566,9 +1511,6 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1592,9 +1534,6 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1615,9 +1554,6 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1637,6 +1573,10 @@ xla_test( xla_test( name = "all_reduce_test", srcs = ["all_reduce_test.cc"], + blacklisted_backends = [ + # All reduce is not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1661,9 +1601,6 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1703,9 +1640,6 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1767,6 +1701,10 @@ xla_test( xla_test( name = "execution_profile_test", srcs = ["execution_profile_test.cc"], + blacklisted_backends = [ + # Execution profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1781,6 +1719,10 @@ xla_test( name = "execution_profile_test_with_xla_hlo_profile", srcs = ["execution_profile_test.cc"], args = ["--xla_hlo_profile"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1794,9 +1736,6 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -1819,9 +1758,6 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1883,9 +1819,6 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -2003,6 +1936,10 @@ xla_test( xla_test( name = "outfeed_in_nested_computation_test", srcs = ["outfeed_in_nested_computation_test.cc"], + blacklisted_backends = [ + # Outfeed ops are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2179,7 +2116,6 @@ xla_test( srcs = ["iota_test.cc"], shard_count = 30, tags = [ - "enable_for_xla_interpreter", # Require optimized builds, iota_test_cpu is very slow in fastbuild. "optonly", ], @@ -2207,3 +2143,18 @@ tf_cc_test( "@com_google_absl//absl/synchronization", ], ) + +xla_test( + name = "ptxas_bug_120501638", + srcs = ["ptxas_bug_120501638.cc"], + tags = [ + # Disabled in OSS until nvidia publicly releases a fixed ptxas. + "no_oss", + ], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 915b456b52215f8d6a9eb6c5b933f3502f1d3d2c..7379fbcc22745f46f2a29732c4bda46f352d07e7 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1443,6 +1443,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = + ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f}); + auto rhs = + ConstantR1(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f}); + Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, + { + {0, 1.41421356}, + {-2.27443288e-01, 0.69999846}, + {-4.19847531e-01, -1.29215783}, + {0, 0}, + {0, 0}, + {1, 0}, + }, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -2047,6 +2068,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f}); + Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); auto minimum = ConstantR0(&builder, 0.0f); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index e9728e636f0ee032416b2da17a3ea83c5bb18083..63e48117056dec4af603cbc85e478fcb15ad0cec 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { error_spec_); } -XLA_TEST_F(Bfloat16Test, BatchNormTraining) { +// Disabled on interpreter since BatchNormExanper is not run by default on the +// intepreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); @@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } -XLA_TEST_F(Bfloat16Test, BatchNormGrad) { +// Disabled on interpreter since BatchNormExanper is not run by default on the +// intepreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a350715597044730429ee9fa268ecd6f2bf26b66..edb95c973b70e30702ed8490c15a48d4d5604170 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -191,7 +191,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( verify_output(actual, ""); // Try with all output layouts. - std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); + std::vector minor_to_major(expected.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto layout = ShapeUtil::MakeShapeWithLayout( @@ -224,7 +224,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { layout_strings.push_back( ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); @@ -234,7 +234,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); + std::vector minor_to_major(literal.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 65a23dd883594b9bf9c37494a37e9be39b197788..3f65ed7fce4ff4b5c3781ac2581935bfacc69ce1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 363dee74b2755a6bdc3c5a5164a85378581c21d2..247328b730f3af936d933f824da491b593b27c90 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -96,7 +96,7 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_TRUE(result.shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( @@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { +// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on +// interpreter backend. +XLA_TEST_F(ClientTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 3b0414a6045a7c5f4f75948d8ccf2775c575626e..ef800b8ef624bf1020ff1e6857c13b0387482cd3 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -151,19 +151,35 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { } } -TEST_F(ComputeConstantTest, IndirectParamMissing) { +TEST_F(ComputeConstantTest, GetDimensionSize) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = - Add(ConstantR0(&b, 1.0f), - Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); + auto add = + Add(ConstantR1(&b, {1.0f}), ConstantR1(&b, {1.0f})); + auto get_dimension_size = GetDimensionSize(add, 0); + EXPECT_TRUE(IsConstant(get_dimension_size, &b)); + + TF_ASSERT_OK_AND_ASSIGN(auto value, ComputeConstantScalar( + client, get_dimension_size, &b)); + EXPECT_EQ(value, 1); + } +} - auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE( - absl::StrContains(value.status().ToString(), "depends on a parameter")) - << value.status(); +TEST_F(ComputeConstantTest, MultipleGetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto add = + Add(ConstantR2(&b, {{1.0f}}), ConstantR2(&b, {{1.0f}})); + auto get_dimension_size = GetDimensionSize(add, 0); + auto get_dimension_size_2 = GetDimensionSize(add, 0); + auto add_2 = Add(get_dimension_size, get_dimension_size_2); + EXPECT_TRUE(IsConstant(add_2, &b)); + + TF_ASSERT_OK_AND_ASSIGN(auto value, + ComputeConstantScalar(client, add_2, &b)); + EXPECT_EQ(value, 2); } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 9174f2651cb90b364f869364fe108cf208c11a84..6530007871ced1d0bbffe2b44ccc8cf9bddd79e1 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -180,6 +181,29 @@ TEST_F(ConstantsTest, Token) { TF_ASSERT_OK(Execute(&builder, {}).status()); } +TEST_F(ConstantsTest, FullLike) { + XlaBuilder b(TestName()); + auto val1 = Iota(&b, F32, 3); + auto val2 = FullLike(val1, 10); + val1 + val2; + ComputeAndCompareR1(&b, {10, 11, 12}, {}, error_spec_); +} + +TEST_F(ConstantsTest, IllegalFullLikeOnTuple) { + XlaBuilder b(TestName()); + auto tuple = Tuple(&b, {Iota(&b, F32, 3), Iota(&b, F32, 1)}); + FullLike(tuple, 10); // Illegal; can't do FullLike on a tuple. + EXPECT_FALSE(b.Build().ok()); +} + +TEST_F(ConstantsTest, FullLikeScalar) { + XlaBuilder b(TestName()); + auto scalar1 = ConstantR0WithType(&b, F32, 1); + auto scalar2 = FullLike(scalar1, 2); + scalar1 - scalar2; + ComputeAndCompareR0(&b, -1, {}, error_spec_); +} + class ConstantsHloTest : public HloTestBase {}; // TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. @@ -200,9 +224,7 @@ XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) { ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR0(1); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal(param, result)); diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c78d3f3d9ee2115206e6c4aeeb2991c07e57392a --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct DepthwiseConvolution2DSpec { + int64 output_batch, window, window_dilation; + std::vector activation_dims; + std::vector kernel_dims; + std::vector output_dims; +}; + +class DepthwiseConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, {16, 20, 20, 256}, + {256, 7, 5, 4}, {256, 6, 6, 4}, {256, 8, 8, 512}}; + + for (auto option : config_options) { + int64 feature = option[3]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[0]; + + DepthwiseConvolution2DSpec config; + config.window_dilation = 1; + config.output_batch = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + + config.kernel_dims = {batch, kernel_size, kernel_size, feature}; + + int64 output_space_size = 3 + activation_size - kernel_size; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + + // Add configurations for window dilation cases. + if (activation_size % 2 == 0 && activation_size == kernel_size) { + DepthwiseConvolution2DSpec config; + config.window_dilation = 2; + config.output_batch = feature; + config.window = kernel_size / 2; + config.activation_dims = {batch, activation_size, activation_size, + feature}; + config.kernel_dims = {batch, kernel_size / 2, kernel_size / 2, feature}; + + int64 output_space_size = 5; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + } + } + + return config_set; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), data_type); + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s] parameter(0) + kernel = %s[%s] parameter(1) + ROOT conv = %s[%s] convolution(%s[%s] activation, %s[%s] kernel), + window={size=%dx%d pad=1_%dx1_%d rhs_dilate=%dx%d}, dim_labels=f01b_i01o->01fb, + batch_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), data_type, + absl::StrJoin(spec.output_dims, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), spec.window, spec.window, + spec.window_dilation, spec.window_dilation, spec.window_dilation, + spec.window_dilation, spec.output_batch); +} + +XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { + const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + DepthwiseConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 249693891290e14645ee5b4b4d97b2d506a01302..9db9f2563b636c4f929585eb13a9c7f809833eda 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // servers. The error message is missing the operator ++. template void iota_int_init_value(std::vector& values, int init_value) { - std::for_each(values.begin(), values.end(), - [&](T& value) { value = static_cast(init_value++); }); + absl::c_for_each(values, + [&](T& value) { value = static_cast(init_value++); }); } template diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c5d8b663f4abe77e05ec213d2e4e075c260a8655..f740f4815810727890583405b2244fceaec0bd3f 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -918,8 +920,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto one = ConstantR0(&builder, 1); + auto zero = ConstantR0(&builder, 0); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -945,8 +948,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -974,8 +978,9 @@ XLA_TEST_F(DotOperationTest, XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1001,8 +1006,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1033,8 +1039,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1065,8 +1072,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1089,8 +1097,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1113,8 +1122,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1147,5 +1157,105 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +using EinsumParamType = + std::tuple, std::vector, string>; +class EinsumTest : public DotOperationTest, + public ::testing::WithParamInterface {}; +XLA_TEST_P(EinsumTest, SimpleEinsumTest) { + XlaBuilder builder(TestName()); + auto x = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam()))) + .ValueOrDie(), + &builder); + auto y = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) + .ValueOrDie(), + &builder); + Einsum(x, y, std::get<2>(GetParam())); + ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); +} + +std::vector GetEinsumTestCases() { + using v = std::vector; + using p = EinsumParamType; + std::vector

test_cases = { + p{v{5, 6}, v{6, 7}, "mk,kn->mn"}, + p{v{5, 6}, v{6, 7}, "mk,kn->nm"}, + p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"}, + p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, + p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, + p{v{6}, v{6, 7}, "b,bc->c"}, + }; + return test_cases; +} + +INSTANTIATE_TEST_CASE_P(Einsum, EinsumTest, + ::testing::ValuesIn(GetEinsumTestCases())); + +class DotOperationTextTest : public HloTestBase {}; + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) { + absl::string_view hlo_string = + R"( +HloModule DotWithNoDnums + +ENTRY %test { + %lhs = f32[2,3]{1,0} parameter(0) + %rhs = f32[4,5]{1,0} parameter(1) + ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, Einsum) { + absl::string_view hlo_string = + R"( +HloModule Einsum + +ENTRY %test { + %lhs = f32[8,64,96]{2,1,0} parameter(0) + %rhs = f32[96,32,4]{2,1,0} parameter(1) + ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7501c6d957e7afe99b8c530e5f0d575f818367da..82e2db36143b2552472fedae701f32389a9be108 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -135,11 +135,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); - DynamicSlice(input, starts, slice_sizes); + DynamicSlice(input, absl::Span({starts}), slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -160,14 +160,23 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -186,14 +195,22 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } }; @@ -372,16 +389,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { .ValueOrDie()); XlaBuilder builder(TestName()); - // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_value); auto update = ConstantLiteral(&builder, update_value); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({})); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_value, {}); } template @@ -405,12 +418,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({starts})); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -435,15 +448,23 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -466,15 +487,24 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -518,8 +548,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = ConstantR1(&builder, {index, 0, 0}); - DynamicUpdateSlice(input, update, starts); + auto constant_index = ConstantR0(&builder, index); + auto zero = ConstantR0(&builder, 0); + DynamicUpdateSlice(input, update, {constant_index, zero, zero}); // Run computation and compare against expected values. ComputeAndCompareR3(&builder, expected_values, @@ -720,46 +751,55 @@ void BM_DynamicSlice(int num_iters) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = ConstantLiteral(&builder, input_literal); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + // Create dynamic slice start indices as a parameter: shape [4] - auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); - auto start_indices = - Parameter(&builder, 0, start_indices_shape, "start_indices"); + auto start_indices_shape = ShapeUtil::MakeShape(S32, {}); + std::vector start_indices(4); + std::vector shaped_buffers; + std::vector host_shapes(4); + for (int i = 0; i < 4; ++i) { + start_indices[i] = + Parameter(&builder, i, start_indices_shape, "start_indices"); + auto start_index_literal = LiteralUtil::CreateR0(i + 1); + // Initialize and transfer parameter buffer. + shaped_buffers.emplace_back( + client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer(start_indices_shape, &allocator, + /*device_ordinal=*/0) + .ConsumeValueOrDie()); + host_shapes[i] = &shaped_buffers[i].on_host_shape(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + stream.get(), start_index_literal, shaped_buffers[i])); + } + // Add DynamicSlice op to the computatation. DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); - // Initialize and transfer parameter buffer. - auto buffer = client->backend() - .transfer_manager() - ->AllocateScopedShapedBuffer( - start_indices_shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); - - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); - auto stream = - client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), start_indices_literal, buffer)); - std::unique_ptr executable = - client - ->Compile(computation, {&buffer.on_host_shape()}, - ExecutableBuildOptions()) + client->Compile(computation, host_shapes, ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; + std::vector shaped_buffer_ptrs; + absl::c_transform(shaped_buffers, std::back_inserter(shaped_buffer_ptrs), + [](const ScopedShapedBuffer& buffer) { return &buffer; }); + for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index c84973e17b234c24c84f02a369ce0185f5772cca..139d5c70b8cbcf14670abcb064fcca2f0ba853c6 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -40,14 +40,15 @@ class ExhaustiveF32ElementwiseOpTest Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); + absl::Span input_arr = input_literal.data(); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - input_literal.Set({i - begin}, 0.0f); + input_arr[i - begin] = 0; } else { - input_literal.Set({i - begin}, absl::bit_cast(i)); + input_arr[i - begin] = absl::bit_cast(i); } } @@ -60,7 +61,7 @@ class ExhaustiveF32ElementwiseOpTest std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal.Get({i}))); + expected_result.push_back(evaluate_op(input_arr[i])); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index dcb469087e0064d17ce3b04fdeaf0b6136069a55..1b0bebe2d03a9a153cd0c80329ed0c49c91333a3 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -48,7 +48,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { tensorflow::SubProcess file_check_process; file_check_process.SetProgram(file_check_path, - {file_check_path, pattern_path}); + {file_check_path, "-v", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, @@ -71,9 +71,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; } - LOG(WARNING) << "FileCheck error: " << standard_error; - LOG(WARNING) << "FileCheck input was:"; - XLA_LOG_LINES(tensorflow::WARNING, input); + LOG(WARNING) << "FileCheck error:\n" << standard_error; LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); } else if (!standard_error.empty()) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index d1fddf9d6b494a822610e41307fa103dc90bdef3..2178c9b3f3d39ac034c59585c6836d2bc59162c1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -523,10 +523,10 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); + ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2})); auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2)); hlo_module->AddEntryComputation(builder.Build()) diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index daa89398a697af9149797d621c3bdca80a00aedd..d65b67a535d43553a3a94f76482ad4618f9b8aab 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -600,7 +600,9 @@ ENTRY main { class GatherClientLibraryTest : public ClientLibraryTestBase {}; -XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { +// Disabled on interpreter since ExectuteAsyncOnStream is not supported. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) { // We create this HLO, but using the XlaBuilder API. // // ENTRY main { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d57846e19bb80c5b9c87d50596da2915f9aef317..66f72ba8d20b8ef1f436da4425b2bb6518ee9a94 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -139,7 +139,8 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( const string& name) { return absl::make_unique( name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> @@ -147,7 +148,8 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = absl::make_unique( TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); TF_RETURN_IF_ERROR(module->Verify()); return std::move(module); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 1d1e7f437296a7493ef7da07039fcf6d273f35bc..69a4f96288c7285010e9adbdc33f1b394f58d8d2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -46,10 +46,12 @@ class VerifiedHloModule : public HloModule { public: VerifiedHloModule(const string& name, const HloModuleConfig& config, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function) : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + verifier_( + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, + /*instruction_can_change_layout_func=*/{}, shape_size_function) {} ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 554eb24d44168caa7d7252015e3d99f2d567df9b..a2fd6070731943f15c773265f428b16f520d02ee 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error_spec, bool detailed_message) { + const ErrorSpec& error_spec, absl::optional detailed_message) { return StatusToAssertion(literal_comparison::Near( expected, actual, error_spec, detailed_message, &OnMiscompare)); } @@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( - expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); + expected, actual, *error, /*detailed_message=*/absl::nullopt, + &OnMiscompare)); } VLOG(1) << "Expects equal"; return StatusToAssertion(literal_comparison::Equal(expected, actual)); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 43cca91f64b2c0fbfde5054a361cf0f95302c23d..d7cf9bed98a3eb7479b6deb6838dc388a0869360 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -93,7 +93,7 @@ class LiteralTestUtil { static ::testing::AssertionResult Near( const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error_spec, - bool detailed_message = false) TF_MUST_USE_RESULT; + absl::optional detailed_message = absl::nullopt) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4..96527886b718bc1ea4ce8cc2d7dbeb2e3ef1d1eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -205,7 +205,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -233,7 +233,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -311,7 +311,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { LiteralUtil::CreateR0(123456789000LL)})); } -XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { +// Disabled on interpreter backend since infeed HLO is unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); @@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } -XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { +// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 3f5135438fc59bea98527b1be30ee49339edd455..1fd9cb055c0bebc0f31496eb82f53a7b7a6cbfba 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), @@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const = f32[4] constant({0, 0, 0, 0}) ROOT select = f32[4] select(gte0, gte1, const) })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); @@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { p1 = f32[3] parameter(0) ROOT map = f32[3] map(p1), to_apply=map_computation })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); @@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); @@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8f2c26f0eea9c7a3b33cd77e5977924c1659535a..e49bcf26bd6e50f8fb36c86f217907b5d4901eae 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER( + DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_CPU(ScalarBF16CountTests)))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. bfloat16 low = static_cast(32.25); @@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } +// This test verifies that the two RNG instructions with the same parameters in +// the same HloComputation produces different values. +XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) { + // Build a U[0,1) computation. + auto build_computation = [this]() { + XlaBuilder builder(TestName()); + auto a = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + auto b = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + Tuple(&builder, {a, b}); + return builder.Build(); + }; + + ExecutionOptions execution_options = execution_options_; + execution_options.set_seed(42); + + Literal result_tuple; + { + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); + } + + auto results = result_tuple.DecomposeTuple(); + ASSERT_EQ(results.size(), 2); + + EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1])); +} + XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), diff --git a/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e5d7db97e88936e7336ed02a5c7a1171254b0cf --- /dev/null +++ b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class PtxasBugTest : public HloTestBase {}; + +// Checks for a bug in ptxas, tracked as Google bug 120501638, and nvidia bug +// 2459377. We never received an explanation of what exactly was going wrong +// here in ptxas. Known-bad in ptxas 10.0.145, known-good in ptxas 10.0.249. +TEST_F(PtxasBugTest, DoIt) { + const char* const kModuleStr = R"( +HloModule test + +add_F32.14 { + lhs.15 = f32[] parameter(0) + rhs.16 = f32[] parameter(1) + ROOT add.17 = f32[] add(lhs.15, rhs.16) +} + +ENTRY testcase { + arg0.1 = f32[2,5,2]{2,1,0} parameter(0) + reshape.2 = f32[2,5,2]{2,1,0} reshape(arg0.1) + constant.3 = f32[] constant(0) + pad.4 = f32[2,6,2]{2,1,0} pad(reshape.2, constant.3), padding=0_0x0_1x0_0 + reshape.5 = f32[2,3,2,2]{3,2,1,0} reshape(pad.4) + transpose.6 = f32[2,2,3,2]{3,0,2,1} transpose(reshape.5), dimensions={2,0,1,3} + reshape.7 = f32[4,3,2]{2,1,0} reshape(transpose.6) + reshape.8 = f32[4,1,3,2]{3,2,1,0} reshape(reshape.7) + transpose.9 = f32[4,2,1,3]{1,3,2,0} transpose(reshape.8), dimensions={0,3,1,2} + convert.10 = f32[4,2,1,3]{1,3,2,0} convert(transpose.9) + constant.12 = f32[] constant(0) + pad.13 = f32[4,2,1,3]{3,2,1,0} pad(convert.10, constant.12), padding=0_0x0_0x0_0x0_0 + constant.11 = f32[] constant(0) + reduce-window.18 = f32[4,2,1,3]{3,2,1,0} reduce-window(pad.13, constant.11), + window={size=1x1x1x1}, to_apply=add_F32.14 + constant.19 = f32[] constant(1) + broadcast.20 = f32[4,2,1,3]{3,2,1,0} broadcast(constant.19), dimensions={} + divide.21 = f32[4,2,1,3]{3,2,1,0} divide(reduce-window.18, broadcast.20) + convert.22 = f32[4,2,1,3]{3,2,1,0} convert(divide.21) + transpose.23 = f32[4,1,3,2]{2,1,3,0} transpose(convert.22), dimensions={0,2,3,1} + reshape.24 = f32[4,3,2]{2,1,0} reshape(transpose.23) + reshape.25 = f32[2,2,3,2]{3,2,1,0} reshape(reshape.24) + transpose.26 = f32[2,3,2,2]{3,1,0,2} transpose(reshape.25), dimensions={1,2,0,3} + reshape.27 = f32[2,6,2]{2,1,0} reshape(transpose.26) + slice.28 = f32[2,5,2]{2,1,0} slice(reshape.27), slice={[0:2], [0:5], [0:2]} + reshape.29 = f32[2,5,2]{2,1,0} reshape(slice.28) + tuple.30 = (f32[2,5,2]{2,1,0}) tuple(reshape.29) + ROOT get-tuple-element.31 = f32[2,5,2]{2,1,0} get-tuple-element(tuple.30), index=0 +})"; + + // Create a module with the true-default flags, not the default-for-testing + // flags. In particular, true-default flags enable unrolling, whereas for + // testing we disable unrolling, and this bug doesn't trigger without + // unrolling. + HloModuleConfig config; + config.set_debug_options(DefaultDebugOptionsIgnoringFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01})); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 22fe4a2670e2e0e1fedc45036a1ceec19f44e42e..16c67d94c76bcf8984a2b3e4cb092026a6924aeb 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -607,7 +607,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of + // values. (Technically, the requirement is that the iota length is + // relatively prime to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -623,9 +626,9 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; - if (use_bfloat16() && Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. + if (use_bfloat16()) { + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. reducer = kMax; } @@ -949,16 +952,16 @@ struct R3ReduceWindowTestData { /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, @@ -1001,17 +1004,19 @@ TEST_P(R3ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of values. + // (Technically, the requirement is that the iota length is relatively prime + // to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; if (use_bfloat16()) { input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); - if (Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. - reducer = kMax; - } + + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. + reducer = kMax; } XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 7ca99a91635e85cd0888e59ecde31e47fec21844..80a6868485c9162d1cb0de24f0adf3f1c1d2503a 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -79,30 +79,28 @@ string PrependDisabledIfIndicated(const string& test_case_name, // heuristic to decide whether the test case should be disabled, and we // determine whether the test case should be disabled by resolving the (test // case name, test name) in a manifest file. -#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::internal::MakeAndRegisterTestInfo( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ - parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ - new ::testing::internal::TestFactoryImpl); \ +#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + \ + private: \ + virtual void TestBody(); \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::test_info_ = \ + ::testing::RegisterTest( \ + #test_case_name, \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ + nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ + return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ + }); \ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() // This is identical to the TEST_F macro from "gtest", but it potentially @@ -111,9 +109,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, // Per usual, you can see what tests are available via --gunit_list_tests and // choose to run tests that have been disabled via the manifest via // --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \ - ::testing::internal::GetTypeId()) +#define XLA_TEST_F(test_fixture, test_name) \ + XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) // Likewise, this is identical to the TEST_P macro from "gtest", but // potentially disables the test based on the DISABLED_MANIFEST file. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index eafa48ed7b8cf2bd67fe767ad36082661dbbd66e..95c89b0ba6f29c453abab88e29bca13ee006455a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -168,7 +169,7 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, StatusOr MakeFakeLiteralInternal(const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( @@ -274,16 +275,9 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -Literal MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { - std::vector start_indices(index_space.size()); - if (engine != nullptr) { - for (int i = 0; i < index_space.size(); ++i) { - std::uniform_int_distribution generator(0, index_space[i]); - start_indices[i] = generator(*engine); - } - } - return LiteralUtil::CreateR1(start_indices); +Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) { + std::uniform_int_distribution generator(0, index_bound); + return LiteralUtil::CreateR0(generator(*engine)); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -300,8 +294,8 @@ std::vector FindConstrainedUses( HloInstruction* instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || - (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = @@ -336,7 +330,7 @@ std::vector FindConstrainedUses( StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - std::vector index_space; + int64 index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; @@ -348,19 +342,16 @@ StatusOr CreateLiteralForConstrainedUses( const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice ? use->shape() : use->operand(1)->shape(); - const int64 rank = ShapeUtil::Rank(indexed_shape); - if (!index_space.empty()) { - TF_RET_CHECK(rank == index_space.size()); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = std::min( - index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i)); - } - } else { - index_space.resize(rank); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); + const int64 first_index = + Cast(use)->first_index_operand_number(); + for (int64 operand = first_index; operand < use->operand_count(); + ++operand) { + if (use->operand(operand) == ¶m) { + index_bound = std::min( + index_bound, + ShapeUtil::GetDimension(indexed_shape, operand - first_index) - + ShapeUtil::GetDimension(slice_shape, + operand - first_index)); } } break; @@ -388,13 +379,14 @@ StatusOr CreateLiteralForConstrainedUses( } int constraint_count = 0; constraint_count += no_duplicates ? 1 : 0; - constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += (index_bound != INT64_MAX) ? 1 : 0; constraint_count += needs_constant ? 1 : 0; if (constraint_count > 1) { return Unimplemented("Conflicting operand generation constraints."); } - if (!index_space.empty()) { - return MakeRandomIndex(index_space, engine); + if (index_bound != INT64_MAX) { + return MakeRandomIndex(index_bound, engine) + .Reshape(param.shape().dimensions()); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: @@ -459,8 +451,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + CHECK_EQ(lhs->shape().rank(), 2); + CHECK_EQ(rhs->shape().rank(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 448a66cfdd897b17cce1c87c050520a2f2eb0ea2..591d6c19228a313f530cdae18f4be37e7b517601 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -79,25 +79,26 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} - ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 3); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 5); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[0].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { @@ -105,28 +106,29 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - update_param.1 = f32[1,2,3]{0,1,2} parameter(3) - update_param.2 = f32[3,2,2]{0,1,2} parameter(4) - - dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) - ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + update_param.1 = f32[1,2,3]{0,1,2} parameter(5) + update_param.2 = f32[3,2,2]{0,1,2} parameter(6) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2) })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 5); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 7); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[0].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { @@ -198,5 +200,33 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14 } } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { + auto module = ParseHloString(R"( +HloModule Test + +ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { + %parameter.1 = f32[20,20]{1,0} parameter(1) + %constant.1 = s32[1]{0} constant({0}) + %parameter.0 = s32[] parameter(0) + %bitcast.3 = s32[1]{0} bitcast(s32[] %parameter.0) + %concatenate.1 = s32[2]{0} concatenate(s32[1]{0} %constant.1, s32[1]{0} %bitcast.3), dimensions={0} + %dynamic-slice.2 = f32[20,1]{1,0} dynamic-slice(f32[20,20]{1,0} %parameter.1, s32[2]{0} %concatenate.1), dynamic_slice_sizes={20,1} + %bitcast.4 = f32[20]{0} bitcast(f32[20,1]{1,0} %dynamic-slice.2) + %dynamic-slice.3 = f32[1]{0} dynamic-slice(f32[20]{0} %bitcast.4, s32[1]{0} %bitcast.3), dynamic_slice_sizes={1} + ROOT %bitcast.5 = f32[] bitcast(f32[1]{0} %dynamic-slice.3) +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + EXPECT_TRUE(ShapeUtil::Equal(args[0].shape(), ShapeUtil::MakeShape(S32, {}))) + << ShapeUtil::HumanString(args[0].shape()); + EXPECT_TRUE( + ShapeUtil::Equal(args[1].shape(), ShapeUtil::MakeShape(F32, {20, 20}))) + << ShapeUtil::HumanString(args[1].shape()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 9c586bdeb05afb7378e92caed1f3edc408e051bf..cdf2c34fcc3cc005e84626c39c8ab301a9040529 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -176,8 +176,9 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3}))); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, + ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3}))); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } @@ -512,8 +513,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; -// Disabled on the interpreter because bitcast doesn't exist on the interpreter. -XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, BitcastAfterGTE) { const char* testcase = R"( HloModule m, is_scheduled=true @@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -559,9 +557,7 @@ XLA_TEST_F(TupleHloTest, ROOT outfeed = token[] outfeed(tuple, token0) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param0 = LiteralUtil::CreateR1({1, 2}); auto param1 = LiteralUtil::CreateR1({2, 3}); auto param4 = LiteralUtil::CreateR0(false); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 6d5f276e82087cedc356691b0ff08df24cec8d20..85212fa56d71088156d2f3edda17f71cdab56da2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -861,7 +861,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Update. auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); + auto starts = Mul(iteration, ConstantR0(&builder, 2)); // UpdateSlice. auto out1 = DynamicUpdateSlice(input, update, starts); @@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. @@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -1299,9 +1299,9 @@ void BM_WhileLoop(int num_iters) { auto one = ConstantR0(&builder, 1.0); auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = ConstantR1(&builder, {0, 0, 0}); + auto zero = ConstantR0(&builder, 0); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e57d072a0632b492b8b6e34439f4e80332b843b6..c7337e8caae8f2ee25f4b25dc22439e08d2ecc25 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -174,9 +174,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, exec_run_options.set_allocator(backend->memory_allocator()); exec_run_options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - ServiceExecutableRunOptions run_options( - exec_run_options, /*borrow_stream=*/nullptr, - backend->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions run_options(exec_run_options, + /*borrow_stream=*/nullptr); std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, @@ -225,14 +224,17 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { line_no++; // Skip 'Execution profile for ....' + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/false, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 99b32c19a52bf2a1f02047a1ceea626947d994fc..52fee4770ab940741723514d742e998b25765f24 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -29,33 +29,6 @@ tf_cc_binary( ], ) -cc_library( - name = "dumped_computation_to_graphviz_library", - srcs = ["dumped_computation_to_graphviz.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_binary( - name = "dumped_computation_to_graphviz", - deps = [ - ":dumped_computation_to_graphviz_library", - "//tensorflow/compiler/xla/service:interpreter_plugin", - ], -) - tf_cc_binary( name = "show_signature", srcs = ["show_signature.cc"], @@ -95,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", + "//tensorflow/compiler/xla/service/gpu:outfeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -281,3 +255,9 @@ tf_cc_binary( "@com_google_absl//absl/strings", ], ) + +sh_test( + name = "interactive_graphviz_build_only_test", + srcs = ["interactive_graphviz_test.sh"], + data = [":interactive_graphviz"], +) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc deleted file mode 100644 index b623556468fb4a5d96be614b6c067d5a1df51a6f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_graphviz some_binary_snapshot_proto* -// -// Dumps a graphviz URL for a snapshot computation to the command line. -// -// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from -// ServiceInterface::SnapshotComputation to disk. -// -// The GraphViz URL is placed into the log stderr, whereas computation -// statistics are printed on stdout (implementation note: getting computation -// statistics is how we trigger compilation to split out a GraphViz URL). - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc index c187222a11ee721b006194a68620c58749707193..4beb099b330cadf4540944979f38681bae07103c 100644 --- a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc +++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc @@ -36,9 +36,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = @@ -75,9 +74,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = @@ -120,9 +118,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); { auto extracted_module = diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index 6c90cde5a75a93837ee149fd9b5a60e6413c2ac4..ac865707f8697e0b94173a2a33e7be52a9564867 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,8 +29,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "absl/strings/string_view_utils.h" -#include "absl/strings/util.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -56,7 +55,8 @@ bool ReadLine(const char *prompt, string *line) { return util::ReadLine(prompt, line); #else std::cout << prompt; - return std::getline(std::cin, *line); + std::getline(std::cin, *line); + return std::cin.good(); #endif } @@ -391,9 +391,9 @@ void DisplayGraphHandle(const Options &opts, const string& handle) { std::cout << handle << std::endl; // If it is a url, try to open it up in the user's browser too. - if (strings::StartsWithIgnoreCase(handle, "http://") || - strings::StartsWithIgnoreCase(handle, "https://") || - strings::StartsWithIgnoreCase(handle, "file://")) { + if (absl::StartsWithIgnoreCase(handle, "http://") || + absl::StartsWithIgnoreCase(handle, "https://") || + absl::StartsWithIgnoreCase(handle, "file://")) { const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser" : opts.browser.c_str(); tensorflow::SubProcess p; @@ -515,7 +515,7 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { << std::endl; continue; } - std::vector tokens = strings::Split(line, ' '); + std::vector tokens = absl::StrSplit(line, ' '); if (tokens[0] == "quit" || tokens[0] == "exit") { break; } else if (tokens[0] == "help") { diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b3e43aa7da062547fb5f187b885e997fc44bbb65 --- /dev/null +++ b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh @@ -0,0 +1,19 @@ +#! /bin/bash +# /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==============================================================================*/ + +# This is a placeholder for a compile-only test for intractive_graphviz tool. + +exit 0 diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 27a8dd13308b29da9a5013ac9f696613981d68bb..c01a47b510c0e4252e350960b995643b39b70d4a 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -73,7 +74,17 @@ namespace { // fields. struct Options { string fake_infeed_shape; - bool generate_fake_infeed = false; + string fake_outfeed_shape; + + // generate_fake_infeed == true is a safe default: If the model has 0 or 1 + // infeeds, then it will work like normal. If the model has more than one + // infeed, it will be an error, but that wouldn't have worked anyway if you + // hadn't passed generate_fake_infeed. + // + // Same for generate_fake_outfeed. + bool generate_fake_infeed = true; + bool generate_fake_outfeed = true; + bool use_fake_data = false; bool print_result = true; int num_runs = 1; @@ -96,6 +107,83 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, .ValueOrDie(); } +absl::optional GetXfeedShape(bool is_infeed, + const HloModuleProto& module, + const Options& opts) { + std::vector xfeed_instrs; + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(is_infeed + ? HloOpcode::kInfeed + : HloOpcode::kOutfeed)) { + xfeed_instrs.push_back(instruction); + } + } + } + + auto log_xfeed_instrs = [&] { + for (const auto& infeed : xfeed_instrs) { + LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " " + << infeed.name(); + } + }; + + auto find_instruction_from_id_or_die = [&](int64 id) { + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.id() == id) { + return instruction; + } + } + } + LOG(FATAL) << "No instruction with id " << id; + }; + + absl::optional xfeed_shape; + string xfeed_name = is_infeed ? "infeed" : "outfeed"; + string fake_xfeed_shape = + is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape; + bool generate_fake_xfeed = + is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed; + if (!fake_xfeed_shape.empty()) { + xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); + } else if (generate_fake_xfeed) { + CHECK_LT(xfeed_instrs.size(), 2) + << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name << " ops."; + if (xfeed_instrs.empty()) { + LOG(INFO) << "Not generating fake " << xfeed_name + << " shape; model has no " << xfeed_name << "s."; + } else if (xfeed_instrs.size() == 1) { + // kInfeed instructions should have a shape (buffer, token). kOutfeed + // instructions should have operand 0 of shape `buffer`. We want to xfeed + // just `buffer`. + xfeed_shape = is_infeed + ? Shape(xfeed_instrs.front().shape()).tuple_shapes(0) + : Shape(find_instruction_from_id_or_die( + xfeed_instrs.front().operand_ids(0)) + .shape()); + LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: " + << ShapeUtil::HumanString(*xfeed_shape); + } else { + LOG(ERROR) << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name + << " ops, but this model has " << xfeed_instrs.size() + << " of them:"; + log_xfeed_instrs(); + LOG(FATAL) << "Can't run model with --generate_fake_infeed."; + } + } else if (!xfeed_instrs.empty()) { + LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name + << " instruction(s), but neither --generate_fake_" << xfeed_name + << " nor --fake_" << xfeed_name + << "_shape was specified. Execution will likely hang."; + log_xfeed_instrs(); + } + + return xfeed_shape; +} + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // @@ -142,54 +230,37 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } - bool provide_infeed = false; - Shape infeed_shape; - if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = ParseShape(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - infeed_shape = std::move(shape_status).ValueOrDie(); - provide_infeed = true; - } else if (opts.generate_fake_infeed) { - for (const auto& comp : computation.proto().computations()) { - for (const auto& instruction : comp.instructions()) { - if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { - CHECK(!provide_infeed) - << "--generate_fake_infeed only works if the model has 0 or 1 " - "infeed ops, but this one has >= 2."; - provide_infeed = true; - infeed_shape = Shape(instruction.shape()); - LOG(INFO) << "Generating fake infeed shape for inferred shape: " - << ShapeUtil::HumanString(infeed_shape); - } - } - } + if (absl::optional infeed_shape = GetXfeedShape( + /*is_infeed=*/true, computation.proto(), opts)) { + auto infeed_data = std::make_shared( + std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie()); + xla::gpu::GetOrCreateInfeedManager() + ->RegisterBeforeGetNextDestinationCallback([infeed_data, client] { + TF_CHECK_OK(client->TransferToInfeed(*infeed_data)); + }); } - // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape, or when - // --generate_fake_infeed is passed and there exists an infeed operation in - // the HloSnapshot. - absl::optional pool; - Literal data; - if (provide_infeed) { - data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); - } - auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(data)); - }; - if (provide_infeed) { - pool.emplace(tensorflow::Env::Default(), "infeed", - /*num_threads=*/1); - pool->Schedule([transfer_infeed]() { - // There may be several infeed buffers needed, however we don't know how - // many. If we proactively transfer too many infeed buffers, we may run - // out of memory. If we transfer too few infeed buffers, the program will - // hang. Therefore, we register a callback that is called when the infeed - // becomes empty, and in this callback we will transfer another fake - // infeed. - auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); - infeed_manager->RegisterOnEmptyCallback(transfer_infeed); - transfer_infeed(); - }); + + absl::optional outfeed_thread_pool; + if (absl::optional outfeed_shape = GetXfeedShape( + /*is_infeed=*/false, computation.proto(), opts)) { + // For each an outfeed that runs, enqueue a task that will consume it. We + // need a thread pool because the act of running an outfeed blocks on there + // being a destination available, and the act of making a destination + // available blocks on there being outfeed data available. + outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed", + /*num_threads=*/1); + auto consume_outfeed = [client, outfeed_shape] { + TF_CHECK_OK( + client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0) + .status()); + VLOG(1) << "Received outfeed data of shape " + << ShapeUtil::HumanStringWithLayout(*outfeed_shape); + }; + xla::gpu::GetOrCreateOutfeedManager() + ->RegisterBeforeGetNextDestinationCallback( + [consume_outfeed, &outfeed_thread_pool] { + outfeed_thread_pool->Schedule(consume_outfeed); + }); } // Do not attempt to run the executable if num_runs is less than 1. @@ -304,8 +375,10 @@ int RealMain(absl::Span args, const Options& opts) { for (int64 i = 0; i < executables.size(); ++i) { LocalExecutable* executable = executables[i].get(); + LOG(ERROR) << "Running iteration " << i; StatusOr result_status = ReplayComputation(snapshots[i], executable, client, opts); + LOG(ERROR) << "iteration complete."; if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); @@ -350,9 +423,14 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape, + "Shape of fake data to outfeed from computation"), tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, - "Whether a fake infeed shape should be generated " - "derived from the computation"), + "Whether a fake infeed shape should be derived " + "from the computation"), + tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, + "Whether a fake outfeed shape should be derived " + "from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index b645acb700b0f168112a40c9c72b4669435f717d..daf678f69017b9eb86cbc464a1f33b434021901d 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -41,6 +41,7 @@ using ::tensorflow::uint32; using ::tensorflow::uint64; using complex64 = std::complex; +using complex128 = std::complex; using ::Eigen::half; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 68cab7387cf1576072f96878b50f07def6862d8b..34b73b5206fa20d6dff7567afd78fd89897c8c33 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -86,7 +86,7 @@ bool IsPermutation(absl::Span permutation, int64 rank) { CHECK_LT(index, rank); output[index] = 0; } - return std::find(output.begin(), output.end(), -1) == output.end(); + return !absl::c_linear_search(output, -1); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 6722641e9d2c177440361e6f0d1f6c0804eb7cda..f2fd17dc99455a921bf875aad2a3661b4d456823 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span permutation); template int64 PositionInContainer(const Container& container, int64 value) { - return std::distance(container.begin(), - std::find(container.begin(), container.end(), value)); + return std::distance(container.begin(), absl::c_find(container, value)); } // Formats the container as a comma-separated string. StrAppend must support diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 51c73b3d17e4c32d9a8a14d3055ab56f02922af3..e001cc35f9fcea2783b3952e825838af6bbece72 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -137,25 +138,23 @@ bool HasPadding(const Window& window) { } bool HasSymmetricPadding(const Window& window) { - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() == dim.padding_high(); - }); + return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() == dim.padding_high(); + }); } bool HasSymmetricPadding(const PaddingConfig& padding_config) { - return std::all_of(padding_config.dimensions().begin(), - padding_config.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.edge_padding_low() == dim.edge_padding_high(); - }); + return absl::c_all_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == + dim.edge_padding_high(); + }); } bool HasNegativePadding(const Window& window) { - return std::any_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() < 0 || dim.padding_high() < 0; - }); + return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() < 0 || dim.padding_high() < 0; + }); } bool HasBaseDilation(const Window& window) { @@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& window) { return true; } bool reversed = window.dimensions()[0].window_reversal(); - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [&](const WindowDimension& dim) { - return dim.window_reversal() == reversed; - }); + return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); } bool HasDilation(const Window& window) { diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 1439f1bcc5cec39203a7cb4b1f8604e7349382c6..60adea5a4a242e5843b41927ba77c197e8fac444 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -1,30 +1,40 @@ """Wrapper around cc_proto_library used inside the XLA codebase.""" -load("//tensorflow/core:platform/default/build_config.bzl", - "cc_proto_library") -load("//tensorflow/core:platform/default/build_config_root.bzl", - "if_static") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "cc_proto_library", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", +) +load("//tensorflow:tensorflow.bzl", "if_cuda_is_configured") # xla_proto_library() is a convenience wrapper around cc_proto_library. -def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): - if kwargs.get('use_grpc_plugin'): - kwargs['use_grpc_namespace'] = True - cc_proto_library(name=name, - srcs=srcs, - deps=deps, - cc_libs = if_static( - ["@protobuf_archive//:protobuf"], - otherwise=["@protobuf_archive//:protobuf_headers"], - ), - protoc="@protobuf_archive//:protoc", - testonly=testonly, - visibility=visibility, - **kwargs) +def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs): + if kwargs.get("use_grpc_plugin"): + kwargs["use_grpc_namespace"] = True + cc_proto_library( + name = name, + srcs = srcs, + deps = deps, + cc_libs = if_static( + ["@protobuf_archive//:protobuf"], + otherwise = ["@protobuf_archive//:protobuf_headers"], + ), + protoc = "@protobuf_archive//:protoc", + testonly = testonly, + visibility = visibility, + **kwargs + ) def xla_py_grpc_library(**kwargs): - # Note: we don't currently define any special targets for Python GRPC in OSS. - _ignore = kwargs - pass - + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass ORC_JIT_MEMORY_MAPPER_TARGETS = [] + +# We link the GPU plugin into the XLA Python extension if CUDA is enabled. +def xla_python_default_plugins(): + return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"]) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 0e8fa73f8170addfa5061b33f3d6882a13890bce..92834dbb02cdcd6383ceec3ffd079834b163ee6a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -230,7 +230,11 @@ message DebugOptions { // Enable fast math with eigen in the HLO evaluator. bool xla_hlo_evaluator_use_fast_path = 106; - // Next id: 107 + // Temporary option to allow support for both the R1 and the scalar index + // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. + bool xla_allow_scalar_index_dynamic_ops = 107; + + // Next id: 108 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -261,6 +265,10 @@ message ExecutionOptions { // computation on. The computation will be partitioned across these devices. // If not provided, the default device will be chosen. repeated DeviceHandle device_handles = 5; + + // Number of replicas of the computation to run. If zero, uses the default + // number of replicas for the XLA service. + int32 num_replicas = 6; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e9c86abe5094244988d3465ef7c949509deaec37..a64e2f5df5cacca05e83f31c941c57abd5ccf4de 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -56,6 +56,7 @@ enum PrimitiveType { // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a @@ -75,7 +76,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 18 + // Next = 19 } // Describes the padding configuration for Pad operation. The padding amount on @@ -188,11 +189,14 @@ message ShapeProto { // The element type for this shape. PrimitiveType element_type = 2; - // The size (number of elements) for each dimension. - // In XLA, dimensions are numbered from 0 to N-1 for an - // N-dimensional array. The first element of 'dimensions' is the size of - // dimension 0, the second element is the size of dimension 1, and so forth. - // Empty list indicates a scalar. + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. @@ -201,6 +205,12 @@ message ShapeProto { // The layout used to back this shape. LayoutProto layout = 5; + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for // the new field. @@ -358,6 +368,7 @@ message LiteralProto { repeated float f32s = 8; repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. repeated LiteralProto tuple_literals = 10; // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; @@ -365,7 +376,7 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 18 + // Next = 19 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 67f475846e5f16060c1080759b0acb4216c4e72b..dc02fd272fd8700c7f8fa64adf7ab57c88bab706 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -11,20 +11,15 @@ cc_library( name = "xrt_state_ops", hdrs = ["xrt_state_ops.h"], deps = [ + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -55,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -62,7 +58,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", + "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ccdf0f02d840600d5e0649c4805e3672d4a1286..2ee1a6cd1aebcdbd65892b33e5044489070ab5c4 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -215,11 +215,6 @@ XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - const Tensor& key_tensor = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()), - errors::Internal("computation key should be a string scalar")); - int64 uid = key_tensor.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); @@ -230,9 +225,13 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { kXRTCompilationCacheResourceName, &cache)); core::ScopedUnref cache_unref(cache); - OP_REQUIRES_OK(ctx, cache->Release(uid)); - - VLOG(2) << "Released computation handle " << uid; + const Tensor& keys_tensor = ctx->input(0); + auto flat_keys = keys_tensor.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, cache->Release(key)); + VLOG(2) << "Released computation handle " << key; + } } } // namespace diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 751329eefc33f3372335c805233dafabbf42bf36..116c193cab65410a5a7c3058f98cc2be2cbe9e67 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -228,8 +229,27 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); + + // The ScopedShapedBuffer returned by the executable Run() API, in case of + // input/output buffer aliasing, might have holes in it, which need to be + // filled using the proper input tuples buffers which are the source of + // aliasing. + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + auto alias_function = + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { + TF_RET_CHECK(alias.parameter_number < input_tuples.size()); + return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias + ? output_tuple->AliasBufferFrom( + *input_tuples[alias.parameter_number], + alias.parameter_index, output_index) + : Status::OK(); + }; + TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); + if (config_proto.return_exploded_tuple() && - xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + output_tuple->on_device_shape().IsTuple()) { int64 tuple_element_count = xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); Tensor* output_tensor; diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 1a5bfac337baf773b84b92af5f88ef7a4c8ba81f..6a7f10652533920ba3fa48fba1d5161f7c4d4530 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -37,6 +37,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate") .HostMemory("handle"), XRTAllocateOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") .Device(DEVICE_XLA_GPU) .HostMemory("base_handle") diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 2e2f3ff116a7b331df8dbd58a9fe40096f524140..e2c223b3dbb2311d0f42e1a36e316fd9d5f66040 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -19,10 +19,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ #define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ +#include #include #include +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -30,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_device.h" #include "tensorflow/compiler/xrt/xrt_state.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -200,6 +205,109 @@ class XRTAllocateOp : public OpKernel { } }; +// Op that allocates memory for a tensor (with optional layout) and transfers it +// to the device, returning an allocation handle. +template +class XRTAllocateFromTensorOp : public OpKernel { + public: + explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + bool make_tuple = false; + OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + if (ctx->HasAttr("layouts")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); + } + OP_REQUIRES( + ctx, tf_shapes_.size() == dtypes_.size(), + errors::InvalidArgument("shapes and dtypes must be the same length")); + std::vector xla_shapes; + for (int i = 0; i < tf_shapes_.size(); i++) { + xla::Shape xla_shape; + OP_REQUIRES_OK( + ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); + xla_shapes.push_back(xla_shape); + } + if (xla_shapes.size() > 1 || make_tuple) { + shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); + } else { + shape_.Swap(&xla_shapes.front()); + } + if (!minor_to_major_.empty()) { + xla::Shape shape_with_layouts; + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, + /*layout_func=*/nullptr, + &shape_with_layouts)); + shape_.Swap(&shape_with_layouts); + } + } + + ~XRTAllocateFromTensorOp() override = default; + XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete; + XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTAllocateFromTensorOp::Compute"; + + OpInputList values; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values)); + OP_REQUIRES(ctx, values.size() == tf_shapes_.size(), + errors::InvalidArgument( + "Wrong number of inputs to XRTAllocateFromTensor: ", + values.size(), " vs. ", tf_shapes_.size())); + + std::vector tensors_data; + for (size_t i = 0; i < values.size(); ++i) { + const Tensor& input_tensor = values[i]; + OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i], + errors::InvalidArgument( + "Input tensor type and input dtype do not match")); + // We allow the requested on-device shape to differ from the shape of the + // input tensor, as long as they have the same number of elements. + OP_REQUIRES( + ctx, + input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(), + errors::InvalidArgument( + "Input tensor must have the number of elements specified " + "in the matching input shape: ", + input_tensor.shape().num_elements(), " vs. ", + tf_shapes_[i].num_elements(), " at index ", i)); + tensors_data.push_back( + static_cast(DMAHelper::base(&input_tensor))); + } + // Use the buffer straight out of the input tensors to create the literal. + xla::BorrowingLiteral literal = + shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_) + : xla::BorrowingLiteral(tensors_data.front(), shape_); + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( + literal, device_ref.backend(), + device_ref.device_ordinal(), &allocation)); + + // Intern takes ownership of our reference to allocation. + int64 key; + OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = key; + ctx->set_output(0, output); + } + + private: + std::vector tf_shapes_; + DataTypeVector dtypes_; + std::vector minor_to_major_; + xla::Shape shape_; +}; + // Op that takes a tuple handle input and returns a handle to a sub-tuple of the // input. template @@ -453,17 +561,17 @@ class XRTReleaseAllocationOp : public OpKernel { void Compute(OpKernelContext* ctx) override { VLOG(1) << "XRTReleaseAllocationOp::Compute"; - const Tensor& allocation_handle = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()), - errors::Internal("handle input should be an int64 scalar")); - int64 key = allocation_handle.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key)); - - VLOG(2) << "Released allocation handle " << key; + const Tensor& allocation_handle = ctx->input(0); + auto flat_keys = allocation_handle.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, + XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + VLOG(2) << "Released allocation handle " << key; + } } }; diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 7b3b50c69559f6003a108fdf6a1325dbdbaa80a6..9dd964e5467cd855d67764a512e95a6a18f482e1 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -44,10 +44,10 @@ REGISTER_OP("XRTReleaseCompilationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards a computation from the compilation cache. The handle cannot be -subsequently used. +Discards one or more computation handles from the compilation cache. +The handle(s) cannot be subsequently used. -'handle' is an id returned from a XRTCompile Op. +'handle' is an ID (or vector of IDs) returned from a XRTCompile Op. )"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index fe6bee0dacf5dc2050613fc9ad34d3235b5a7b63..2e743fec4963a52ee1abf64525f26e3d89479670 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -26,12 +26,41 @@ REGISTER_OP("XRTAllocate") .SetShapeFn(tensorflow::shape_inference::ScalarShape) .Doc( R"( -Reads a literal proto and transfers it to TPU device memory. +Reads a literal proto and transfers it to device memory. -'allocation' is a serialized xrt::TPUAllocation proto. +'allocation' is a serialized xrt::XLAAllocation proto. 'handle' is an id that can be used in other ops to refer to the allocation. )"); +REGISTER_OP("XRTAllocateFromTensor") + .Input("inputs: dtypes") + .Output("handle: int64") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") + .Attr("make_tuple: bool = false") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Reads a list of tensors with optional layouts, and transfers it to device +memory. + +inputs: The tensors holding the input data. +shapes: The shapes which the tensors should have on device. The i-th shape +corresponds to the i-th input. The shapes, together with the (optional) +layouts, helps creating the fully qualified shape of the data on the device. +The shapes can differ from the corresponding input one, as long as the total +number of elements matches. In other words, it is possible to feed an input +tensor with shape {8} and have a corresponding shape {2,2,2}. +layouts: A vector holding the requested layout in minor-to-major sequence. +If empty, the default layout wil be used. +For a tuple, the layouts vector holds a linearized minor-to-major numbers +for all the tuple leaves, in the order they appear within the tuple. +The elements within the layouts sequence corresponding to a given tuple +subshape can be set to -1, to leave such subshape to the default shape. +handle: An id that can be used in other ops to refer to the allocation. +)"); + REGISTER_OP("XRTSubTuple") .Input("base_handle: int64") .Input("shape_index: int32") @@ -127,10 +156,11 @@ REGISTER_OP("XRTReleaseAllocationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards an allocation from device memory. The handle cannot be subsequently +Discards one or more device memory handles. The handle(s) cannot be subsequently used. -'handle' is the id returned from the Op that produced the on-device allocation. +'handle' is the ID (or a vector of IDs) returned from the Op that produced the +on-device allocation. )"); REGISTER_OP("XRTReleaseAllAllocations") diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index be44a3474acdeb9905c1d21b932fa0dd10b5a212..3a19327e5b5d8072fbecdbe10e9959c8491780eb 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/cc:client_session", "//tensorflow/cc:ops", "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 5f8121703e108f26b048feb7a0412a282f52892c..1111f8240512e81c10a42a28c09f5b0a94daf1ee 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -53,6 +55,14 @@ string DeviceFromFlag() { return absl::StrCat("/device:", xla_test_device, ":0"); } +std::vector GetAttrLayout(absl::Span minor_to_mayor) { + std::vector layout; + for (auto dim : minor_to_mayor) { + layout.push_back(static_cast(dim)); + } + return layout; +} + xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); @@ -96,14 +106,21 @@ xla::LiteralProto FloatMatrix( return array.ToProto(); } +xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { + xla::LiteralProto response; + CHECK(response.ParseFromString(outputs[idx].scalar()())); + return xla::Literal::CreateFromProto(response).ValueOrDie(); +} + bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() - << " != " << b.DebugString(); + LOG(INFO) << "LiteralProtos don't match:\n" + << a.DebugString() << "\n!=\n" + << b.DebugString(); } return equal; } @@ -113,8 +130,19 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = a == l_b; if (!equal) { - LOG(INFO) << "Literal and LiteralProto don't match " - << a.ToProto().DebugString() << " != " << b.DebugString(); + LOG(INFO) << "Literal and LiteralProto don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << b.DebugString(); + } + return equal; +} + +bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) { + bool equal = a == b; + if (!equal) { + LOG(INFO) << "Literals don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << b.ToProto().DebugString(); } return equal; } @@ -215,6 +243,120 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocFromTensor) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = + GetAttrLayout(literal.shape().layout().minor_to_major()); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTuple) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal1 = + xla::LiteralUtil::CreateR2({{14.0f, -5.0f}, {16.0f, 17.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + Tensor tensor1; + TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1}, + {tensor0.shape(), tensor1.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTupleSingle) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorRelayout) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + // Use inverse array layout with the tensor data above. + std::vector layout({0, 1}); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + // We have sent literal's data (in array layout) with a attribute layout + // {0,1}, so the expected literal read from device needs to be changed + // accordingly. + xla::Literal expected_literal = + xla::LiteralUtil::CreateR2({{4.0f, 6.0f}, {5.0f, 7.0f}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response)); +} + TEST(RawApiTest, AllocAndRewrite) { xrt::XLAAllocation alloc; *alloc.mutable_value() = @@ -258,8 +400,102 @@ TEST(RawApiTest, AllocAndRewrite) { EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); - auto release = - ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + Tensor release_tensor(DT_INT64, TensorShape({1})); + release_tensor.flat()(0) = allocation_handle; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocReleaseMany) { + xrt::XLAAllocation alloc1; + *alloc1.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + xrt::XLAAllocation alloc2; + *alloc2.mutable_value() = + xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value1 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString()); + auto value2 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString()); + auto handle1 = ops::XRTAllocate(root, value1); + auto handle2 = ops::XRTAllocate(root, value2); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle1 = outputs[0].scalar()(); + int64 allocation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = allocation_handle1; + release_tensor.flat()(1) = allocation_handle2; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, CompileAndReleaseMany) { + xrt::XLAComputation c1; + auto config1 = c1.mutable_config(); + auto shapes1 = config1->mutable_program_shape(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot()); + + xrt::XLAComputation c2; + auto config2 = c2.mutable_config(); + auto shapes2 = config2->mutable_program_shape(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); + StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(false); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation1 = + ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); + auto c_handle1 = ops::XRTCompile(root, computation1); + auto computation2 = + ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString()); + auto c_handle2 = ops::XRTCompile(root, computation2); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 compilation_handle1 = outputs[0].scalar()(); + int64 compilation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = compilation_handle1; + release_tensor.flat()(1) = compilation_handle2; + + auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); + outputs.clear(); TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, &outputs)); } @@ -845,6 +1081,107 @@ TEST(RawApiTest, LeakCompilationReference) { TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); } +TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { + xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2}); + xla::Shape shape = + xla::ShapeUtil::MakeTupleShape({element_shape, element_shape}); + xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + xla::XlaBuilder builder("ReuseBuffer"); + auto param = xla::Parameter(&builder, 0, shape, "param"); + auto p0 = xla::GetTupleElement(param, 0); + auto p1 = xla::GetTupleElement(param, 1); + auto add = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {add, sub, p0, p1}); + + // Flip the tuple literals in the input handle. + builder.SetUpAlias({1}, 0, {0}); + builder.SetUpAlias({0}, 0, {1}); + + auto computation = builder.Build().ValueOrDie(); + + auto literal0 = xla::LiteralUtil::CreateR1({1.0f, 2.0f}); + auto literal1 = xla::LiteralUtil::CreateR1({5.0f, 9.0f}); + auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + + xrt::XLAAllocation param_alloc; + *param_alloc.mutable_value() = literal.ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = shape.ToProto(); + *shapes->mutable_result() = return_shape.ToProto(); + StoreComputationSnapshot(computation, c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + ClientSession session(root); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto c_data = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, c_data); + auto param_value = ops::Const(root.WithDevice("/device:CPU:0"), + param_alloc.SerializeAsString()); + auto param_handle = ops::XRTAllocate(root, param_value); + TF_ASSERT_OK(root.status()); + + std::vector outputs; + TF_EXPECT_OK(session.Run({param_handle}, &outputs)); + + int64 alloc_handle = outputs[0].scalar()(); + + // Note that we release the result handle immediately, but since we aliased + // the output buffers onto the input allocation ones (held in alloc_handle), + // we can fetch the result from there. + auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)}); + auto read_back = ops::XRTReadLiteral(root, result); + auto release = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_back), result); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back}, + {release}, &outputs)); + + xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); + auto exec_literal_parts = exec_literal.DecomposeTuple(); + ASSERT_EQ(exec_literal_parts.size(), 4); + + EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0)); + EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1)); + + // Now we read back the original input handle values, which at this point + // should contain the result of the XLA computation. + auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + auto release_handle = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_handle), Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle}, + {release_handle}, &outputs)); + + xla::Literal return_literal = ReadOutputLiteral(outputs, 0); + + auto expected_literal0 = xla::LiteralUtil::CreateR1({6.0f, 11.0f}); + auto expected_literal1 = xla::LiteralUtil::CreateR1({-4.0f, -7.0f}); + // The first element of the computation returned tuple would be the add + // (expected_literal0), but since we flipped the buffers, the sub + // (expected_literal1) should come first. + auto expected_literal = + xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0}); + + EXPECT_TRUE(CompareLiterals(return_literal, expected_literal)); +} + TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAAllocation p0; *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); @@ -862,6 +1199,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XRTExecutionConfig e; e.set_release_input_handles(true); e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); auto e_config = diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index d1405eae468492748ae88d842334a922dce272c6..8bf0f28d2233d9e7593365bc42187e327a1c4ac4 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -273,6 +273,8 @@ Status XRTCompilationCache::Lookup( return Status::OK(); } -string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; } +string XRTCompilationCache::DebugString() const { + return "XRTCompilationCache"; +} } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c43d0fc47873abdc82ee937c155bebc346a05f17..7398e847d8b744f947adb03e1bcfd5c0a5b2cc55 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -118,7 +118,7 @@ class XRTCompilationCache : public ResourceBase { // EntryRef holding the program is returned in entry. Status Lookup(int64 uid, std::unique_ptr* entry); - string DebugString() override; + string DebugString() const override; private: // An entry in the compilation cache. The entry is deleted once it has been diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 343460ff107fa81be127950837f786fe4eeadf26..1e2a9584f88b73d7c92a929e93af60376a59170b 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -133,7 +133,8 @@ Status AllocateScopedShapedBuffer( XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, int device_ordinal, xla::DeviceMemoryAllocator* allocator) - : allocation_(allocation), + : size_(allocation.size()), + allocation_(allocation), device_ordinal_(device_ordinal), allocator_(allocator) { if (VLOG_IS_ON(2)) { @@ -181,7 +182,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } /*static*/ Status XRTTupleAllocation::CreateAndTransfer( - const xla::Literal& literal, xla::Backend* backend, int device_ordinal, + const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation) { auto transfer_manager = backend->transfer_manager(); auto allocator = backend->memory_allocator(); @@ -223,8 +224,19 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); + + // Validate the allocation buffers as if nulls gets to + // TransferLiteralFromDevice() a CHECK is issued. + xla::ShapedBuffer shaped_buffer = ToShapedBuffer(); + for (auto& index_buffer : shaped_buffer.buffers()) { + if (index_buffer.second.is_null()) { + return errors::InvalidArgument("Literal buffer at index ", + index_buffer.first.ToString(), + " has been released"); + } + } TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), ToShapedBuffer())); + stream.get(), shaped_buffer)); return Status::OK(); } @@ -505,11 +517,34 @@ xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() { return shaped_buffer; } +Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index) { + XRTBufferAllocation* source_buffer = source.buffers_.element(source_index); + XRTBufferAllocation* dest_buffer = buffers_.element(dest_index); + // We allow the destination size being zero, because there are cases where we + // are coming in later filling in null/uninitialized device buffers. + // In all other cases, the size of the new buffer must match. + if (source_buffer->size() != dest_buffer->size() && + dest_buffer->size() != 0) { + return errors::InvalidArgument( + "Source buffer at index ", source_index.ToString(), + " does not match the size of destination buffer at index ", + dest_index.ToString(), ": ", source_buffer->size(), " vs ", + dest_buffer->size()); + } + *buffers_.mutable_element(dest_index) = source_buffer; + source_buffer->Ref(); + dest_buffer->Unref(); + return Status::OK(); +} + xla::ShapeTree -XRTTupleAllocation::ToDeviceMemoryTree(bool release) { +XRTTupleAllocation::ToDeviceMemoryTree( + const std::function& release_checker) { xla::ShapeTree shaped_tree(on_device_shape()); for (const auto& buffer : buffers_) { - if (!release) { + if (!release_checker(buffer.first)) { *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation(); } else { *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory( diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 3e3d5024124e13b87eed6f79596d50cd64325914..ddf2656e6f51775024a6d1cd0d7a387605faae6f 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ +#include #include #include #include @@ -58,7 +59,14 @@ class XRTBufferAllocation : public core::RefCounted { // freed when the reference count drops to zero. void DiscardAllocation(); + // Returns the expected size of the allocation. Since DiscardAllocation() will + // set allocation_ to {null,0}, and since later we might want to replace the + // discarded buffer with a new one, we need to be able to verify the size + // compatibility. + uint64 size() const { return size_; } + private: + uint64 size_ = 0; se::DeviceMemoryBase allocation_; int device_ordinal_; xla::DeviceMemoryAllocator* allocator_; @@ -80,7 +88,7 @@ class XRTTupleAllocation : public ResourceBase { // Allocates new device memory buffers sufficient to store literal, transfers // literal to that memory, and returns a XRTTupleAllocation handle to the // allocated buffers. - static Status CreateAndTransfer(const xla::Literal& literal, + static Status CreateAndTransfer(const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation); @@ -168,11 +176,20 @@ class XRTTupleAllocation : public ResourceBase { // the same shape as on_host_shape. xla::ShapedBuffer ToShapedBuffer(); - // Returns the device memory tree of this allocation. If 'release' is set, the - // ownership of the device memory is transferred to the result. - xla::ShapeTree ToDeviceMemoryTree(bool release); + // Aliases the source buffer at source_index into the current tuple allocation + // dest_index. + Status AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index); + + // Returns the device memory tree of this allocation. If the release_checker + // function returns true for a given index, the ownership of the device memory + // at that index is transferred to the result. Every attempt to read the value + // at that index will fail. + xla::ShapeTree ToDeviceMemoryTree( + const std::function& release_checker); - string DebugString() override { return "XLA allocation handle"; } + string DebugString() const override { return "XLA allocation handle"; } private: // Creates a new handle with (tuple) shape. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 832db0f4ab46911e067d17b4a125706c276cf798..a4c3d9623adfe3133af0c6ea055586b9544e659b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -197,7 +196,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_kernels", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + "//tensorflow/contrib/tensorrt:trt_op_kernels", ]), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f1a2a5693235183c8f486817b82c8c81fa389ec..48d5296c71cbdb470fa405b30547a32b7022f29b 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -20,13 +20,14 @@ from __future__ import division from __future__ import print_function import os +import platform # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -if os.name != "nt": +if os.name != "nt" and platform.machine() != "s390x": from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder @@ -91,7 +92,6 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.recurrent.python import recurrent_api as recurrent @@ -103,6 +103,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") del os +del platform + del LazyLoader del absolute_import diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 44532cb078f9bd1578172f8a7d8a4b55cd21a7cb..831c613f2c8c9a4fcc2cb9d313077fe79ee96fd7 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -186,8 +186,8 @@ "\n", " def __init__(self):\n", " super(RnnColorbot, self).__init__()\n", - " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", - " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256, dtype=tf.float32)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128, dtype=tf.float32)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", @@ -241,7 +241,7 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = (length - 1, range(batch_size))\n", + " indices = (length - 1, list(range(batch_size)))\n", " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 648f3ebb05646a66144bcb118347cbc391909409..5174afe0a63d37e3ea3e19ac9bab644d1d83ecf1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -37,6 +37,7 @@ py_library( cc_library( name = "batch_ops_kernels", deps = [ + "//tensorflow/core:batch_ops_op_lib", "//tensorflow/core/kernels:batch_kernels", ], alwayslink = 1, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index 4652021fecabfa11fa6a8754dc884d89e151b590..e3b4535bac4a01a1277290e0d1ea6d3c7613731c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -42,7 +42,7 @@ class BigtableClientResource : public ResourceBase { return client_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("BigtableClientResource(project_id: ", project_id_, ", instance_id: ", instance_id_, ")"); } @@ -67,7 +67,7 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table& table() { return table_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat( "BigtableTableResource(client: ", client_->DebugString(), ", table: ", table_name_, ")"); diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index 3fe71a2ea730cc9b60b2e2088a0d80a08b38d1a9..e6fda9e61757f1441b3691c2a3d57c6f1a5a0d42 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" -#include "google/bigtable/v2/data.pb.h" +#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h" #include "google/protobuf/wrappers.pb.h" #include "re2/re2.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -410,6 +410,17 @@ BigtableTestClient::AsyncCheckAndMutateRow( return nullptr; } +std::unique_ptr< + grpc::ClientAsyncReaderInterface> +BigtableTestClient::AsyncReadRows( + grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index 85705904573e9e7710912e3f4ff30dd8fed5bf85..8e1326f2ce841368ea81fc7194a0588e5d6cd637 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -87,6 +87,12 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const google::bigtable::v2::CheckAndMutateRowRequest& request, grpc::CompletionQueue* cq) override; + std::unique_ptr< + grpc::ClientAsyncReaderInterface> + AsyncReadRows(grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index ee052ac60387d8f993e4942dd7dff39e191dd3a4..47d910d42a27db4b857eeb12209dfbb429dd1be2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -487,8 +487,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -516,8 +516,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -806,8 +806,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -835,8 +835,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index e446c411a8d5075563b8f8b912b29df310e16c8c..6faf6963011b698a3b233329d87471da7608e44a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -96,7 +96,7 @@ class StatsAccumulatorResource : public boosted_trees::StampedResource { TensorShapeUtils::IsScalar(hessian_shape)); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("StatsAccumulatorResource[size=", values_.size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 42d69645acaae063fcd46bd1f6c819ccb68f48bd..aa3f24f08a0f762507df83def72e7d595265221f 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -227,7 +227,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): tree_ensemble_config=tree_ensemble_config.SerializeToString(), name="restore_tree") resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() + variables.global_variables_initializer().run() my_saver = saver.Saver() # Add the second tree and replace the ensemble of the handle. diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index fca22c71a83459cb290eaebcf107cf1c14c222b7..c3685b54e201f73039f6623443c67ba2b217a51e 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -62,8 +62,8 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(ensemble_config, slice_spec, name + "_config"), ] - super(TreeEnsembleVariableSavable, - self).__init__(tree_ensemble_handle, specs, name) + super(TreeEnsembleVariableSavable, self).__init__(tree_ensemble_handle, + specs, name) self._tree_ensemble_handle = tree_ensemble_handle self._create_op = create_op @@ -115,7 +115,7 @@ class TreeEnsembleVariable(tracking.TrackableResource): def _gather_saveables_for_checkpoint(self): return { - "tree_ensemble_variable": + self.resource_handle.op.name + "/tree_ensemble_variable": functools.partial( TreeEnsembleVariableSavable, tree_ensemble_handle=self.resource_handle, @@ -131,8 +131,8 @@ def tree_ensemble_variable(stamp_token, Args: stamp_token: The initial stamp token value for the ensemble resource. - tree_ensemble_config: A `Tensor` of type `string`. - Serialized proto of the tree ensemble. + tree_ensemble_config: A `Tensor` of type `string`. Serialized proto of the + tree ensemble. name: A name for the ensemble variable. container: An optional `string`. Defaults to `""`. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index a5951fb7377d48748f5eb578c034176517df7749..e78ec476ab3b43e5eb56a2502008bb8020ae97e0 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -566,9 +566,10 @@ class GradientBoostedDecisionTreeModel(object): # Determine if ensemble is colocated with the inputs. if self._ensemble_handle.device != input_deps[0].device: # Create a local ensemble and get its local stamp. - with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name: + with ops.name_scope("local_ensemble", "TreeEnsembleVariable"): local_ensemble_handle = ( - gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name)) + gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._ensemble_handle.op.name + "/local_ensemble")) create_op = gen_model_ops.create_tree_ensemble_variable( local_ensemble_handle, stamp_token=-1, tree_ensemble_config="") with ops.control_dependencies([create_op]): diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 92068e88a76cb8bfdd394c1093347a8fb8a63449..7e45d0b2cecefa4bdec77d6cf7cfca7dba04db9c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keepdims=True) + math_ops.squared_difference(predictions, label), 1, keepdims=True) return loss diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 220e981618b7c0bfb1e4e98c087d83b451b9b3cf..1ad40aca2880940c78d746674c7378ff0427c057 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -166,7 +166,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. """ unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keepdims=True) + math_ops.squared_difference(predictions, labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 94aeb2c7bb48c6eddb6c7894f8bf6f1567470113..0fe57c0a4e8375cc7ec7aca9553bded87e238b33 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -34,7 +34,7 @@ class DecisionTreeEnsembleResource : public StampedResource { protobuf::Arena::CreateMessage< boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", decision_tree_ensemble_->trees_size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fdaaae7f472c8f564ab45a8366d3746cbf1158ee..574e3065e7f46049815897ef73e44d33f0d23f0f 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -43,7 +43,7 @@ class QuantileStreamResource : public StampedResource { set_stamp(stamp_token); } - string DebugString() override { return "QuantileStreamResource"; } + string DebugString() const override { return "QuantileStreamResource"; } tensorflow::mutex* mutex() { return &mu_; } diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 94b7f4f867655bf7fdf94e8488eeae7088c41622..99ed4959fad9699f265183d71a1f3b609d7e6d30 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -51,11 +51,11 @@ from tensorflow.contrib.checkpoint.python.split_dependency import split_dependen from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import CheckpointableBase +from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index ada41687261ab63286933d01da4e286173042e0c..4e529322c7c76797938468b405cd175609dc0a73 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "checkpoint", @@ -27,17 +27,17 @@ py_library( ], ) -py_test( +tf_py_test( name = "containers_test", srcs = ["containers_test.py"], - deps = [ + additional_deps = [ ":containers", + "@six_archive//:six", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/training/checkpointable:base", "//tensorflow/python/training/checkpointable:util", - "@six_archive//:six", ], ) @@ -53,18 +53,18 @@ py_library( ], ) -py_test( +tf_py_test( name = "python_state_test", srcs = ["python_state_test.py"], - deps = [ + additional_deps = [ ":python_state", + "//third_party/py/numpy", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", "//tensorflow/python/training/checkpointable:util", - "//third_party/py/numpy", ], ) @@ -80,10 +80,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "split_dependency_test", srcs = ["split_dependency_test.py"], - deps = [ + additional_deps = [ ":split_dependency", "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", @@ -106,10 +106,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "visualize_test", srcs = ["visualize_test.py"], - deps = [ + additional_deps = [ ":visualize", "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 5418e2605b724edb60878e250d2c50fcc6ff5633..97936d9e9dfd5d6e62fdf8312707a276b63e1267 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -63,7 +63,7 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): ValueError: If `checkpointable` is not a checkpointable object. """ - if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + if not isinstance(checkpointable, checkpointable_lib.Checkpointable): raise ValueError( ("Expected a checkpointable value, got %s which does not inherit " "from CheckpointableBase.") % (checkpointable,)) diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index ac85c7be803cd4c2f8ba19d3ef887a3c65a15933..a2d453ec6eb3dcf9aba4c52fe866756a92673c63 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.Checkpointable() + restore_slots = tracking.AutoCheckpointable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.Checkpointable): + class SlotManager(tracking.AutoCheckpointable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 302d5cfb79a08b6adf52ebd44533152c5454eadc..969c90c78871ebff02b360f8f09623df56c9c077 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -34,7 +34,7 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.CheckpointableBase): +class NumpyState(base.Checkpointable): """A checkpointable object whose NumPy array attributes are saved/restored. Example usage: @@ -130,7 +130,7 @@ class NumpyState(base.CheckpointableBase): @six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.CheckpointableBase): +class PythonStateWrapper(base.Checkpointable): """Wraps a Python object for storage in an object-based checkpoint.""" @abc.abstractmethod diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 7e77453f3d848c2e321ed2ba66917a742d95459a..3e9700ad74618e24843181d169f3fb39ac96bff6 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.CheckpointableBase): +class _SplitDependency(checkpointable.Checkpointable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 00a805af25d5d0ea723db5d015fb12bf45c53857..664a4e76ab31bf31c7a57924e4af866f2d746804 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.Checkpointable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -59,14 +59,14 @@ class SaveTensorSlicesAsDeps(base.CheckpointableBase): self._track_checkpointable(dep, name=name) -class HasRegularDeps(tracking.Checkpointable): +class HasRegularDeps(tracking.AutoCheckpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.Checkpointable): +class OnlyOneDep(tracking.AutoCheckpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 1311063ec023bdaa2588d6f1c826bf900f7dea09..20f8c2b2453a58fdbe5a3587fa6687debd9c06d3 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -27,7 +27,6 @@ tf_kernel_library( deps = [ ":bigquery_table_accessor", ":bigquery_table_partition_proto_cc", - "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:reader_base", @@ -79,7 +78,6 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index df8b48dfc46124d3b9454d92ffb70dbcf1bc4217..60ee1b4b3fd7d0b6afaefcc05effd3bbae00cf2c 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -147,19 +147,19 @@ suitable interface for project configuration and dependency setting. * Go (required if you need ssl support, optional) * NASM/YASM (required by grpc for ssl support, optional) 2. Start CMake GUI -3. Click on `Browse Source` and direct to the the folder +3. Click on `Browse Source` and direct to the folder `/tensorflow/contrib/cmake` 4. Click on `Browse Build` and spectify a location that you want tensorflow to be build 5. Click on `Configure`, a new window will be prompted out, specify the generator mode for the project generation. For Windows, choose `Visual Studio Win64`, for Linux, choose `Unix Makefiles`, then - press `Finish`. Wait for a moment, the default project dependecy would + press `Finish`. Wait for a moment, the default project dependency would automatically generate. 6. There are a few options that you can customize your own build. **The setting - here is crucial for a sucessful build, please check all items carefully.** + here is crucial for a successful build, please check all items carefully.** - * `tensorflow_BUILD_ALL_KERNELS` should alway be `on` + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. This can help you to test build (optional) * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't @@ -278,7 +278,7 @@ suitable interface for project configuration and dependency setting. `make -sj install` Where `` is the threads used for the compilation, change - to any integer less or equal to your computer's maxiumum thread number. + to any integer less or equal to your computer's maximum thread number. Headers are discretely located in the build folders. Tensorflow library can be found at ``, namely `tensorflow.so` (Linux) or diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 46a193971c5084523d432065f265fa7a9909f595..6c6a5df7f76723800740a81ccdcb137a0ec33846 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,17 +31,17 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + add_custom_target(abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) else (systemlib_ABSEIL_CPP) include (ExternalProject) - set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) - set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) - set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) - set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp.git) + set(abseil_cpp_TAG master) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -49,8 +49,11 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/Release/absl_time.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) else() set(abseil_cpp_STATIC_LIBRARIES @@ -62,6 +65,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/absl_time.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) endif() else() @@ -74,15 +78,18 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/time/libabsl_time.a ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp + ExternalProject_Add(abseil_cpp_build PREFIX abseil_cpp - URL ${abseil_cpp_URL} - URL_HASH ${abseil_cpp_HASH} + GIT_REPOSITORY ${abseil_cpp_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release + COMMAND ${CMAKE_COMMAND} --build . --config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -91,8 +98,10 @@ else (systemlib_ABSEIL_CPP) ) include_directories(${abseil_cpp_INCLUDE_DIR}) + message(STATUS ${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 479609458c64f7c7bd7b3ce6b23aceaa3db17f21..b15143bfc1cd787b156c9d6dd724a17730f0f8fb 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.1) +set(nsync_TAG 1.20.2) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 96160568fa79291a7b391761373e1eaf0f70974e..21ae9a08a6bb8f71e5935ddde2d7bb3ed0cd8bbc 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -1,6 +1,9 @@ # python_sanity_test.py will complain about invalid or missing entries # problematic entries can be commented for temporary whitelisting tensorflow +tensorflow/compiler +tensorflow/compiler/xla +tensorflow/compiler/xla/service tensorflow/core tensorflow/core/example tensorflow/core/framework diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index d7b2a1339e047aba0a9424a53a63726805e89721..d8d1cc3aa2ca4fff3c950654b7cbd7085c76010c 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -147,7 +147,6 @@ set(tf_proto_text_srcs "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" "tensorflow/core/framework/graph_transfer_info.proto" - "tensorflow/core/framework/iterator.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" @@ -302,8 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" + "${tensorflow_source_dir}/tensorflow/core/summary/*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -317,14 +316,14 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*test*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/loader.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/vacuum.cc" ) # TODO(jart): Why doesn't this work? # set_source_files_properties( -# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# ${tensorflow_source_dir}/tensorflow/core/lib/db/snapfn.cc # PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8faccf8d55902e6701ebb4ce534b84705304fd5f..1fe8795ddf00232eba5a60a130e0845a6f6a8e17 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -802,6 +802,7 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index f867cd15b67dbd43650d8012b4299845af7200a8..0f1be500f499ebba7e1907de663f8bbfa889bb17 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import function_utils +from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -76,10 +77,22 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin All `Operation`s returned from `computation` will be executed when evaluating any of the returned output tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimention list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. Returns: - A list of output tensors. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. """ # pylint: disable=protected-access return _compile_internal(computation, inputs) @@ -245,13 +258,21 @@ def _compile_internal(computation, inputs=None): Args: computation: A Python function that builds the computation to compile and execute. - inputs: A list of input tensors or `None` (equivalent to `[]`). Its order - should match ordering of computation arguments. + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. + Returns: - A list of output tensors from computation. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: 1) None output 2) Single + value output 3) Operation-only outputs Raises: ValueError: If any element in computation outputs is neither an operations or a value that can be converted to tensor. + ValueError: If computation outputs is non-flat and contains any Operations. TypeError: If `inputs` is not a list or tuple. """ if inputs is None: @@ -260,17 +281,10 @@ def _compile_internal(computation, inputs=None): if not isinstance(inputs, collections.Sequence): raise TypeError('inputs must be a list') + # Flatten inputs. + flat_inputs = nest.flatten(inputs) # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - input_arity = len(inputs) - - arg_error = check_function_argument_count( - computation, input_arity, infeed_queue=None) - if arg_error is not None: - raise TypeError( - 'Supplied computation cannot be called with the specified inputs. You ' - 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs]), arg_error)) + flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -280,11 +294,15 @@ def _compile_internal(computation, inputs=None): # Add identity ops so even unused inputs are 'consumed' by the # computation. - computation_inputs = [ + flat_inputs = [ array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(inputs) + for i, x in enumerate(flat_inputs) ] + # Re-pack flat_inputs in same structure as 'inputs'. + computation_inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + # Only resource variables work inside an XLA computation, so turn on # resource variables for the computation. vscope = variable_scope.get_variable_scope() @@ -297,66 +315,166 @@ def _compile_internal(computation, inputs=None): # Restore variable scope after computation. vscope.set_use_resource(saved_use_resource) - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that return value of this function always contains - # at least one op that can trigger XlaLaunch node. - outputs += (control_flow_ops.no_op(),) - try: - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be Operations' - ' or convertible to Tensors. Got error: "%s"' % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - 'XLA computation function must return zero or more Tensor values ' - 'followed by zero or more Operations.') - output_arity = len(output_tensors) - - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else ''): - new_output_tensors.append(array_ops.identity(t)) + outputs_is_flat = is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() - outputs = [ - xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i)) - for i in xrange(output_arity) + # When XLA computation returns only operations and no tensors, a NoOp + # dependent on the operations in outputs is returned. Otherwise final + # outputs would be empty and there is no way to trigger returned + # operations. + if not output_tensors: + return control_flow_ops.group(control_deps, name='output_0') + + output_tensors = [ + xla_ops.xla_cluster_output(o, name='output{}'.format(i)) + for i, o in enumerate(output_tensors) ] - with ops.control_dependencies(output_operations): - if output_arity == 0: - # When XLA computation returns only operations and no tensors, a NoOp - # dependent on the operations in outputs is returned. Otherwise final - # outputs would be empty and there is no way to trigger returned - # operations. - return control_flow_ops.no_op(name='output_0') - else: - # Wraps the outputs in identity operators that carries control - # dependencies. - return [ - array_ops.identity(outputs[i], name='output_%d' % i) - for i in xrange(output_arity) - ] + with ops.control_dependencies(control_deps): + # Wraps the outputs in identity operators that carries control + # dependencies. + output_tensors = [ + array_ops.identity(o, name='output_%d' % i) + for i, o in enumerate(output_tensors) + ] + + # If `computation` returned non-flat output structure, pack output tensors + # back into same structure. + if not outputs_is_flat: + output_tensors = nest.pack_sequence_as( + structure=outputs, flat_sequence=output_tensors) + + return output_tensors + + +def is_flat(outputs): + """Checks if outputs is a flat structure. + + Following structures and values are considered flat: + 1) None + 2) A single object + 3) A list or tuple of Tensors/Operations + + The only structures that this function understands are sequences and + dictionaries. E.g. this means that if outputs contains a single + user-defined Object, it is considered to be flat. Errors are raised later on + if that Object cannot be converted to a Tensor. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + A boolean indicates whether outputs is flat. + """ + # If outputs is a list or tuple, check if it has any nested structure. If + # there is, then outputs is non-flat. + if isinstance(outputs, collections.Sequence): + for o in outputs: + if isinstance(o, collections.Sequence) or isinstance(o, dict): + return False + + # If outputs is a dict, it is non-flat. + if isinstance(outputs, dict): + return False + + # Getting here means either outputs itself is a single non-structured value + # or it is a flat list of single non-structured values. + return True + + +def _postprocess_flat_outputs(outputs): + """Validates flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors and Operations extracted from outputs. + """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that return value of this function always contains + # at least one op that can trigger XlaLaunch node. + outputs += (control_flow_ops.no_op(),) + try: + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be Operations' + ' or convertible to Tensors. Got error: "%s"' % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + 'XLA computation function must return zero or more Tensor values ' + 'followed by zero or more Operations.') + + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else ''): + new_output_tensors.append(array_ops.identity(t)) + + return new_output_tensors, output_operations + + +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs.. + """ + # Convert all non-Operation outputs to Tensors. + new_output_tensors = [] + for o in nest.flatten(outputs): + if isinstance(o, ops.Operation): + raise ValueError( + 'xla.compile does not support Operation as return value in non-flat ' + 'output structure. You can set returned Operations as control ' + 'dependencies of returned Tensors so Operations are triggered when ' + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be ' + 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) + + # Makes sure even pass-through inputs/outputs are touched in compile + # context by creating an Identity node inside compile context. + with ops.device(o.device if o.device else ''): + new_output_tensors.append(array_ops.identity(o)) + + return new_output_tensors, [] @contextlib.contextmanager diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD index eee4329acbeb38c9f37f79227aeb3acd46dce5e7..619153df67c90cea5a5082a411972948bac5fe90 100644 --- a/tensorflow/contrib/constrained_optimization/BUILD +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -42,11 +42,6 @@ py_test( name = "candidates_test", srcs = ["python/candidates_test.py"], srcs_version = "PY2AND3", - tags = [ - # TODO(b/121223093): Re-enable this test after fixing "Distribution - # should match known solution" errors. - "no_mac", - ], deps = [ ":constrained_optimization", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py index a4c49d48bc5c763489215261a909573af0f19055..280e9acd88638a9385bfd9128ba6d3739879aab2 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates_test.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -52,12 +52,12 @@ class CandidatesTest(test.TestCase): distribution = candidates.find_best_candidate_distribution( objective_vector, constraints_matrix) # Verify that the solution is a probability distribution. - self.assertTrue(np.all(distribution >= 0)) + self.assertTrue(np.all(distribution >= -1e-6)) self.assertAlmostEqual(np.sum(distribution), 1.0) # Verify that the solution satisfies the constraints. maximum_constraint_violation = np.amax( np.dot(constraints_matrix, distribution)) - self.assertLessEqual(maximum_constraint_violation, 0) + self.assertLessEqual(maximum_constraint_violation, 1e-6) # Verify that the solution matches that which we expect. expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 7e1b4062ce435f3ab4216e90b4f5fcbab984c1dc..ca92c31236a7a3882415834eb32a994a120b6d2d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -1023,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase): outputs_v, output_state_v = sess.run( [outputs, output_state], feed_dict={cell_inputs: inference_input}) - self.assertAllClose(cudnn_outputs_v, outputs_v, atol=2e-5, rtol=2e-5) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-4, rtol=2e-4) (cudnn_output_h_v,) = cudnn_output_states_v self.assertAllClose(cudnn_output_h_v, output_state_v, atol=2e-5, rtol=2e-5) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 1facc83972faf229f243af5bc534bcb98aff5440..f36e8d5022bc7e3f8268a161089153e5510dffc6 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -837,7 +837,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() + cell = checkpointable_lib.AutoCheckpointable() checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 8a8dc159ade6f2a4a9b5ec29055ea4848492b29f..dbcaf8185fb7a9d2bcf22376439c0ebd49accb1a 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -43,28 +43,19 @@ the workers. Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras). -Take a very simple model consisting of a single layer: +Let's define a simple input dataset for training this model. Note that currently we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. ```python import tensorflow as tf from tensorflow import keras -inputs = tf.keras.layers.Input(shape=(1,)) -predictions = tf.keras.layers.Dense(1)(inputs) -model = tf.keras.models.Model(inputs=inputs, outputs=predictions) -``` - -Let's also define a simple input dataset for training this model. Note that currently we require using -[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) -with `DistributionStrategy`. - -```python features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) train_dataset = tf.data.Dataset.zip((features, labels)) ``` - To distribute this Keras model on multiple GPUs using `MirroredStrategy` we first instantiate a `MirroredStrategy` object. @@ -72,14 +63,17 @@ first instantiate a `MirroredStrategy` object. distribution = tf.contrib.distribute.MirroredStrategy() ``` -We then compile the Keras model and pass the `MirroredStrategy` object in the -`distribute` argument (apart from other usual arguments like `loss` and -`optimizer`). +Take a very simple model consisting of a single layer. We need to create and compile +the model under the distribution strategy scope. ```python -model.compile(loss='mean_squared_error', - optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=distribution) +with distribution.scope(): + inputs = tf.keras.layers.Input(shape=(1,)) + predictions = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs=inputs, outputs=predictions) + + model.compile(loss='mean_squared_error', + optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2)) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 8ec73654e30e4967f318c558ba94301e84a206e4..59d76f5d1c817d7f2cc8ad285b9fb517fe994a81 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -30,12 +30,13 @@ from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import initialize_tpu_system from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server -from tensorflow.python.training.distribute import * -from tensorflow.python.training.distribution_strategy_context import * +from tensorflow.python.distribute.distribute_lib import * +from tensorflow.python.distribute.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented @@ -58,11 +59,14 @@ _allowed_symbols = [ 'StandardSingleLossStep', 'ReplicaContext', 'TPUStrategy', + 'initialize_tpu_system', 'get_cross_replica_context', 'get_distribution_strategy', 'get_loss_reduction', 'get_replica_context', + 'get_strategy', 'has_distribution_strategy', + 'has_strategy', 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index d2fb878f96f55200d870447b45f3d0a37c6b0f86..1b455a4e644417561a7556e66465f2cb093776d5 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,5 +1,7 @@ # Implementation of a prototype TF distributed computation library. +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") +load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") @@ -13,8 +15,18 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# TODO(priyag): Figure out testonly issues that are preventing us from -# including our tests in pip for now. +py_library( + name = "distribute_test_lib_pip", + visibility = ["//tensorflow:internal"], + deps = [ + ":combinations", + ":keras_correctness_test_lib", + ":keras_test_lib", + ":multi_worker_test_base", + ":single_loss_example", + ":strategy_test_lib", + ], +) cuda_py_test( name = "values_test", @@ -22,25 +34,36 @@ cuda_py_test( additional_deps = [ ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], - tags = [ - "no_pip", +) + +cuda_py_test( + name = "input_lib_test", + srcs = ["input_lib_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", ], ) @@ -50,8 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", ], ) @@ -60,18 +83,10 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:parameter_server_strategy", "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) @@ -104,7 +119,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -118,6 +132,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", @@ -138,7 +154,9 @@ py_library( "//tensorflow/python:training", "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:cross_device_utils", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], @@ -146,12 +164,8 @@ py_library( py_library( name = "strategy_test_lib", - testonly = 1, srcs = ["strategy_test_lib.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -164,17 +178,14 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", ], ) py_library( name = "combinations", - testonly = 1, srcs = ["combinations.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ ":mirrored_strategy", ":one_device_strategy", @@ -186,6 +197,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], ) @@ -193,9 +205,6 @@ py_library( py_test( name = "combinations_test", srcs = ["combinations_test.py"], - tags = [ - "no_pip", - ], deps = [ ":combinations", "//tensorflow/python/eager:test", @@ -206,9 +215,6 @@ py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ ":one_device_strategy", ":strategy_test_lib", @@ -242,18 +248,13 @@ cuda_py_test( tags = [ "guitar", "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "multi_worker_test_base", - testonly = 1, srcs = ["multi_worker_test_base.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -288,6 +289,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], @@ -320,14 +323,16 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) -py_library( - name = "minimize_loss_test_lib", - testonly = 1, +distribute_py_test( + name = "minimize_loss_test", srcs = ["minimize_loss_test.py"], + main = "minimize_loss_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":mirrored_strategy", @@ -347,18 +352,6 @@ py_library( ], ) -cuda_py_test( - name = "minimize_loss_test", - srcs = ["minimize_loss_test.py"], - additional_deps = [ - ":minimize_loss_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - cuda_py_test( name = "moving_averages_test", srcs = ["moving_averages_test.py"], @@ -372,9 +365,6 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -392,7 +382,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -415,7 +404,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -429,7 +417,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -459,7 +446,6 @@ cuda_py_test( shard_count = 48, tags = [ "multi_and_single_gpu", - "no_pip", # TODO(b/118768923): Re-enable {a,m,t}san test. "noasan", "nomsan", @@ -481,10 +467,13 @@ py_library( ], ) -py_library( - name = "step_fn_test_lib", - testonly = 1, +distribute_py_test( + name = "step_fn_test", srcs = ["step_fn_test.py"], + main = "step_fn_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":single_loss_example", @@ -497,18 +486,6 @@ py_library( ], ) -cuda_py_test( - name = "step_fn_test", - srcs = ["step_fn_test.py"], - additional_deps = [ - ":step_fn_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - py_library( name = "monitor", srcs = ["monitor.py"], @@ -536,7 +513,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -553,9 +529,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -577,13 +550,11 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "keras_test_lib", - testonly = 1, srcs = [ "keras_backward_compat_test.py", "keras_test.py", @@ -602,43 +573,50 @@ py_library( ], ) -cuda_py_test( +distribute_py_test( name = "keras_test", srcs = ["keras_test.py"], - additional_deps = [ - ":keras_test_lib", - ], + main = "keras_test.py", shard_count = 16, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], + deps = [ + ":keras_test_lib", + ], ) # TODO(b/121200287): Remove this in 2.0 -cuda_py_test( +distribute_py_test( name = "keras_backward_compat_test", srcs = ["keras_backward_compat_test.py"], - additional_deps = [ - ":keras_test_lib", - ], - shard_count = 16, + full_precision = True, + main = "keras_backward_compat_test.py", + shard_count = 31, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], + deps = [ + ":keras_test_lib", + ], ) py_library( name = "keras_correctness_test_lib", - testonly = 1, - srcs = ["keras_correctness_test.py"], + srcs = [ + "keras_correctness_test_base.py", + "keras_dnn_correctness_test.py", + "keras_embedding_model_correctness_test.py", + "keras_image_model_correctness_test.py", + "keras_lstm_model_correctness_test.py", + "keras_stateful_lstm_model_correctness_test.py", + ], deps = [ ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", @@ -653,13 +631,95 @@ py_library( ], ) -cuda_py_test( - name = "keras_correctness_test", - srcs = ["keras_correctness_test.py"], - additional_deps = [ +distribute_py_test( + name = "keras_dnn_correctness_test", + size = "medium", + srcs = ["keras_dnn_correctness_test.py"], + full_precision = True, + main = "keras_dnn_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 19, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ ":keras_correctness_test_lib", ], - shard_count = 16, +) + +distribute_py_test( + name = "keras_image_model_correctness_test", + size = "medium", + srcs = ["keras_image_model_correctness_test.py"], + full_precision = True, + main = "keras_image_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_embedding_model_correctness_test", + size = "medium", + srcs = ["keras_embedding_model_correctness_test.py"], + full_precision = True, + main = "keras_embedding_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_stateful_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_stateful_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_stateful_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. @@ -667,12 +727,18 @@ cuda_py_test( "no_windows_gpu", "notsan", ], + deps = [ + ":keras_correctness_test_lib", + ], ) -py_library( - name = "metrics_v1_test_lib", - testonly = 1, +distribute_py_test( + name = "metrics_v1_test", srcs = ["metrics_v1_test.py"], + main = "metrics_v1_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", "//tensorflow/python:math_ops", @@ -684,18 +750,6 @@ py_library( ], ) -cuda_py_test( - name = "metrics_v1_test", - srcs = ["metrics_v1_test.py"], - additional_deps = [ - ":metrics_v1_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - cuda_py_test( name = "warm_starting_util_test", size = "medium", @@ -710,7 +764,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -729,6 +782,25 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", + ], +) + +tf_xla_py_test( + name = "checkpointing_test", + srcs = ["checkpointing_test.py"], + disabled_backends = [ + # Only makes sense on TPUs + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = [ + "no_oss", + ], + deps = [ + ":tpu_strategy", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index 31bd0e996a247a2fc01405fb3b8172a40853d698..3ef8b9574a36730dcc1d8fd42b6c7f364d84bbed 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -71,7 +71,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - distribution.call_for_each_replica(init_and_verify, args=[g]) + distribution.extended.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5b9f57b8a5bc12ee94399ec1fc5a55177a5b5d --- /dev/null +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.platform import test +from tensorflow.python.training import adam as adam_v1 +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class NonLayerCheckpointable(tracking.AutoCheckpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class Subclassed(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(Subclassed, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + # We can still track Checkpointables which aren't Layers. + self._non_layer = NonLayerCheckpointable() + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class TrainingCheckpointTests(xla_test.XLATestCase): + + def testEagerTPUDistributionStrategy(self): + self.skipTest("b/121387144") + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + def _train_fn(optimizer, model): + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + functools.partial(model, input_value), + global_step=root.optimizer_step) + + for training_continuation in range(3): + strategy = tpu_strategy.TPUStrategy() + with strategy.scope(): + model = Subclassed() + optimizer = adam_v1.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + + for _ in range(num_training_steps): + strategy.extended.call_for_each_replica( + functools.partial(_train_fn, optimizer, model)) + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.optimizer_step.numpy()) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index e6bbf0c308a6abb6bbddb5ca9291a641ad518501..aa4d82b4d0c0dffc66115346d5f82a9d64bcfa56 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -26,9 +26,12 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops @@ -85,9 +88,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: local_devices = ("/device:CPU:0",) self._worker_device = device_util.canonicalize("/device:CPU:0") + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) + # TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce. self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, @@ -120,6 +125,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): task_id) self._worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) if num_gpus_per_worker: local_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) @@ -130,7 +136,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): self._collective_keys = cross_device_utils.CollectiveKeys() self._initialize_local(local_devices) - self._input_workers = values.InputWorkers( + self._input_workers = input_lib.InputWorkers( self._device_map, [(self._worker_device, self.worker_devices)]) self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, @@ -156,19 +162,23 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device - group_size = device_map.num_replicas_in_graph * self._num_workers - group_key = self._collective_keys.get_group_key(self.worker_devices) def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" - value_list = [] unique_var_name = ops.get_default_graph().unique_name( kwargs["name"], mark_as_used=False).rstrip("/") + # pylint: disable=protected-access collective_instance_key = self._collective_keys.get_instance_key( key_id=unique_var_name) + # Only the first device participles in the broadcast of initial values. + group_key = self._collective_keys.get_group_key([devices[0]]) + group_size = self._num_workers if "initial_value" not in kwargs: raise ValueError("Initial value must be specified.") initial_value = kwargs["initial_value"] @@ -177,9 +187,33 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: initial_value_fn = lambda: initial_value + value_list = [] for i, d in enumerate(devices): - with ops.device(d): - if i > 0: + with ops.init_scope(), ops.device(d): + if i == 0: + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + assert index == 0, index + if self._num_workers > 1: + if self._is_chief: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + return initial_value + else: # Give replicas meaningful distinct names: var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to @@ -187,30 +221,22 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other devices and other workers. - def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor(initial_value) - - if self._is_chief and index == 0: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) + # Variables on non-first replica get initial values from the + # variables created on the first device of each worker. + def _overridden_initial_value_fn(device=d, index=i): + assert index > 0 + with ops.device(device): + if context.executing_eagerly(): + return array_ops.identity(value_list[0].value()) + else: + return array_ops.identity(value_list[0].initial_value) kwargs["initial_value"] = _overridden_initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) + # Don't record operations (e.g. other variable reads) during + # variable creation. + with tape.stop_recording(): + v = next_creator(*args, **kwargs) if i == 0: actual_var_name = v.name.split(":")[0] @@ -222,19 +248,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # pylint: disable=protected-access return mirrored_strategy._create_mirrored_variable( - device_map, logical_device, _real_mirrored_creator, *args, **kwargs) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - # TODO(yuefengz): shard the dataset. - worker_index = 0 - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, worker_index, - prefetch_on_device=True) + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -251,7 +270,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): input_pipeline_id=input_pipeline_id, num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [input_context]) def _configure(self, @@ -345,4 +364,11 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): - return False + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `make_input_fn_iterator` assumes per-replica batching. + + Returns: + Boolean. + """ + return True diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 0fb672dded7624e798592d2f5c01945aa830021e..9b6236fd9f89ec30c1234c846930a05d9c32e99d 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -123,7 +123,7 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=[one]) + g_v = d.extended.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -135,7 +135,7 @@ class CollectiveAllReduceStrategyTestBase( g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): + d.extended.update(v, update, args=(g,), group=False)): after_list.append(d.extended.read_var(v)) return before_list, after_list @@ -192,6 +192,7 @@ class CollectiveAllReduceStrategyTestBase( image = random_ops.random_uniform([2, 28, 28]) label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) logits = model(image, training=True) + # TODO(yuefengz): make loss a callable for eager mode. loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) optimizer = adam.AdamOptimizer(learning_rate=1e-4) train_op = optimizer.minimize(loss, @@ -202,7 +203,7 @@ class CollectiveAllReduceStrategyTestBase( self.cached_session(config=config, target=master_target) as sess: with d.scope(): - train_op = d.call_for_each_replica(model_fn) + train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(d.unwrap(train_op)) sess.run(variables.global_variables_initializer()) @@ -225,7 +226,7 @@ class CollectiveAllReduceStrategyTestBase( 1.0, 10.0, dtype=dtypes.float32)) return array_ops.identity(x) - x = distribution.call_for_each_replica(model_fn) + x = distribution.extended.call_for_each_replica(model_fn) reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] @@ -397,28 +398,38 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): - def testMinimizeLossGraph(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], num_gpus=[2, 4], required_gpus=2)) + def testMinimizeLoss(self, num_gpus): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_minimize_loss_graph(None, None, num_gpus) + if context.executing_eagerly(): + strategy, _, _ = self._get_test_object(None, None, num_gpus) + self._test_minimize_loss_eager(strategy) + else: + self._test_minimize_loss_graph(None, None, num_gpus) - def testComplexModel(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[2, 4], required_gpus=2)) + def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) - def testMakeInputFnIterator(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. - if context.num_gpus() < num_gpus: - self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(10) - expected_values = [[i, i+1] for i in range(0, 10, 2)] + @combinations.generate( + combinations.combine(mode=['graph', 'eager'], required_gpus=2)) + def testMakeInputFnIterator(self): + num_gpus = 2 + dataset_fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + expected_values = [range(i, i + num_gpus) for i in range(0, 10, num_gpus)] input_fn = self._input_fn_to_test_input_context( dataset_fn, @@ -428,6 +439,49 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, self._test_input_fn_iterator(None, None, num_gpus, input_fn, expected_values) + def testAllReduceSum(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object(None, None, num_gpus=2) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradient_tape(distribution) + + def testNumpyIterator(self): + num_gpus = 2 + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + strategy, _, _ = self._get_test_object(None, None, num_gpus) + self._test_numpy_iterator(strategy) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 4a934953ad2d4c6ecbe2bde2333a49bf8fd72821..db79d6c0d8ac4590b0e16598a3fde6f89d4759a6 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,16 +46,21 @@ import unittest from absl.testing import parameterized import six -from tensorflow.contrib.cluster_resolver import TPUClusterResolver +from tensorflow.contrib import cluster_resolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2 +from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2 +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2 +from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2 from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -321,6 +326,23 @@ class NamedDistribution(object): return self._required_tpu +def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs): + def _create_tpu_strategy(): + resolver = cluster_resolver.TPUClusterResolver("") + topology = tpu_lib.initialize_tpu_system(resolver) + device_assignment = None + if use_single_core: + device_assignment = device_assignment_lib.DeviceAssignment( + topology, core_assignment=device_assignment_lib. + SINGLE_CORE_ASSIGNMENT) + + strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run, + device_assignment=device_assignment, + **kwargs) + return strategy + return _create_tpu_strategy + + # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", @@ -330,13 +352,31 @@ one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) tpu_strategy = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=2), + "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPUOneStep", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=1), + "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), + required_tpu=True) +tpu_strategy_loop_on_device_one_core = NamedDistribution( + "TPULoopOnDeviceOneCore", _get_tpu_strategy_creator( + steps_per_run=2, use_single_core=True, + _disable_training_loop_on_host=True), + required_tpu=True) +tpu_strategy_one_step_loop_on_device_one_core = NamedDistribution( + "TPUOneStepLoopOnDeviceOneCore", _get_tpu_strategy_creator( + steps_per_run=1, use_single_core=True, + _disable_training_loop_on_host=True), required_tpu=True) +# TODO(b/122327153): Remove below two NamedDistributions. +tpu_strategy_loop_on_device = NamedDistribution( + "TPULoopOnDevice", _get_tpu_strategy_creator( + steps_per_run=2, _disable_training_loop_on_host=True), + required_tpu=True) +tpu_strategy_one_step_loop_on_device = NamedDistribution( + "TPUOneStepLoopOnDevice", _get_tpu_strategy_creator( + steps_per_run=1, _disable_training_loop_on_host=True), + required_tpu=True) + mirrored_strategy_with_one_cpu = NamedDistribution( "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) @@ -386,10 +426,20 @@ gradient_descent_optimizer_v2_fn = NamedObject( adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1.0)) optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] +gradient_descent_optimizer_keras_v2_fn = NamedObject( + "GradientDescentKerasV2", + lambda: gradient_descent_keras_v2.SGD(0.2)) +adagrad_optimizer_keras_v2_fn = NamedObject( + "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001)) +adam_optimizer_keras_v2_fn = NamedObject( + "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0)) +rmsprop_optimizer_keras_v2_fn = NamedObject( + "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001)) + graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 84b106545e1326fddd3ed299462534af982dc102..5f89df5824a8d03198987a6fa3d21e2330deedf0 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -31,6 +31,12 @@ py_binary( py_binary( name = "keras_mnist", + srcs = ["keras_mnist.py"], + deps = [":keras_mnist_lib"], +) + +py_library( + name = "keras_mnist_lib", srcs = [ "keras_mnist.py", ], @@ -39,3 +45,14 @@ py_binary( "//third_party/py/numpy", ], ) + +py_binary( + name = "mnist_eager_multigpu", + srcs = [ + "mnist_eager_multigpu.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 60fda996642464135fe1fb8c314bcf7f04d19362..1ce91ecaf22a80a53124c8f00fac05c6b4711ed9 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -109,22 +109,21 @@ def main(_): tf.enable_eager_execution() train_ds, eval_ds, input_shape = get_input_datasets() - model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) - optimizer = rmsprop.RMSProp(learning_rate=0.001) - - # Compile the model by passing the distribution strategy object to the - # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed - # based on the strategy instantiated. - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=optimizer, - metrics=['accuracy'], - distribute=strategy) + # Create and compile the model under Distribution strategy scope. + # `fit`, `evaluate` and `predict` will be distributed based on the strategy + # model was compiled with. + with strategy.scope(): + model = get_model(input_shape) + optimizer = rmsprop.RMSProp(learning_rate=0.001) + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=optimizer, + metrics=['accuracy']) # Train the model with the train dataset. model.fit(x=train_ds, epochs=20, steps_per_epoch=468) diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py new file mode 100644 index 0000000000000000000000000000000000000000..11a3b5e91a52a6881d48a0aceadbd901a46e86b2 --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py @@ -0,0 +1,151 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run MNIST on multiple GPUs on using MirroredStrategy with eager execution. + +By default, runs on all available GPUs, or CPU if no GPUs are available. + +NOTE: Currently, this takes more time than when running MNIST in eager without +MirroredStrategy because of a number overheads. Therefore, this is just a +proof of concept right now and cannot be used to actually scale up training. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +tf.flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?" + "Defaults to all available GPUs, otherwise CPU.") +tf.flags.DEFINE_integer("batch_size", 64, + "What should be the size of each batch?") +tf.flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") +tf.flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") +tf.flags.DEFINE_float("momentum", 0.5, "SGD momentum") + +FLAGS = tf.flags.FLAGS +NUM_TRAIN_IMAGES = 60000 +NUM_TEST_IMAGES = 10000 + + +def create_model(): + max_pool = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding="same") + # The model consists of a sequential chain of layers, so tf.keras.Sequential + # (a subclass of tf.keras.Model) makes for a compact description. + return tf.keras.Sequential([ + tf.keras.layers.Reshape( + target_shape=[28, 28, 1], + input_shape=(28, 28,)), + tf.keras.layers.Conv2D(2, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Conv2D(4, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.4), + tf.keras.layers.Dense(10)]) + + +def compute_loss(logits, labels): + loss = tf.reduce_sum( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + # Scale loss by global batch size. + return loss * (1. / FLAGS.batch_size) + + +def mnist_datasets(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. + x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) + y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) + # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of + # converting to datasets. + train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_dataset, test_dataset + + +def main(unused_argv): + """Run a CNN model on MNIST data to demonstrate DistributedStrategies.""" + + tf.enable_eager_execution() + + num_gpus = FLAGS.num_gpus + if num_gpus is None: + devices = None + elif num_gpus == 0: + devices = ["/device:CPU:0"] + else: + devices = ["/device:GPU:{}".format(i) for i in range(num_gpus)] + strategy = tf.distribute.MirroredStrategy(devices) + + with strategy.scope(): + train_ds, test_ds = mnist_datasets() + train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) + test_ds = test_ds.batch(FLAGS.batch_size) + + model = create_model() + optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) + training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) + training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "training_accuracy", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + + def train_step(inputs): + images, labels = inputs + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + training_loss.update_state(loss) + training_accuracy.update_state(labels, logits) + + def test_step(inputs): + images, labels = inputs + logits = model(images, training=False) + loss = compute_loss(logits, labels) + test_loss.update_state(loss) + test_accuracy.update_state(labels, logits) + + train_iterator = strategy.make_dataset_iterator(train_ds) + test_iterator = strategy.make_dataset_iterator(test_ds) + for epoch in range(0, FLAGS.num_epochs): + # Train + print("Starting epoch {}".format(epoch)) + train_iterator.initialize() + for _ in range(NUM_TRAIN_IMAGES // FLAGS.batch_size): + strategy.experimental_run(train_step, train_iterator) + print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( + training_loss.result(), training_accuracy.result() * 100)) + training_loss.reset_states() + training_accuracy.reset_states() + + # Test + test_iterator.initialize() + for _ in range(NUM_TEST_IMAGES // FLAGS.batch_size): + strategy.experimental_run(test_step, test_iterator) + print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format( + test_loss.result(), test_accuracy.result() * 100)) + test_loss.reset_states() + test_accuracy.reset_states() + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..10a58316ec5b3d9d968a88c5c39ff70c277daa65 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -0,0 +1,246 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the input_lib library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import errors +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = input_lib.InputFunctionIterator( + input_fn, input_workers, input_contexts) + else: + iterator = input_lib.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class SplitDatasetBatchTest(test.TestCase): + + def testBatchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testMapAndBatchDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testPrefetchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py index 93c0280c8215712071457cafb9c6040f7d97fa60..3bc84dc009bf91493d10d28ef7c3b718ef17ba91 100644 --- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os from absl.testing import parameterized import numpy as np @@ -27,20 +26,12 @@ from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import values from tensorflow.python.eager import test -from tensorflow.python.estimator import keras as keras_lib -from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.ops.parsing_ops import gen_parsing_ops -from tensorflow.python.platform import gfile -from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop @@ -325,15 +316,20 @@ def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -# TODO(priyag): Add v2 optimizers here. def strategy_and_optimizer_combinations(): + # TODO(b/122372746): Uncomment optimizers after they pass tests. return combinations.times( all_strategy_combinations(), - combinations.combine( - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn])) + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + # combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + # combinations.rmsprop_optimizer_keras_v2_fn + ])) def strategy_and_input_combinations(): @@ -359,298 +355,9 @@ def strategy_for_numpy_input_combinations(): mode=['graph']) -class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, - parameterized.TestCase): - - def setUp(self): - self._base_dir = os.path.join(self.get_temp_dir(), - 'keras_mirrored_strategy_test') - gfile.MakeDirs(self._base_dir) - self._config = run_config_lib.RunConfig( - tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) - - def tearDown(self): - writer_cache.FileWriterCache.clear() - if os.path.isdir(self._base_dir): - gfile.DeleteRecursively(self._base_dir) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_train_functional_with_distribution_strategy(self, distribution): - keras_model = simple_functional_model() - keras_model.compile( - loss='categorical_crossentropy', - metrics=[keras.metrics.CategoricalAccuracy()], - optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) - config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, - model_dir=self._base_dir, - train_distribute=distribution, - eval_distribute=distribution) - with self.cached_session(): - est_keras = keras_lib.model_to_estimator( - keras_model=keras_model, config=config) - before_eval_results = est_keras.evaluate( - input_fn=get_ds_test_input_fn, steps=1) - est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) - after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, - steps=1) - self.assertLess(after_eval_results['loss'], before_eval_results['loss']) - - writer_cache.FileWriterCache.clear() - gfile.DeleteRecursively(self._config.model_dir) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_train_sequential_with_distribution_strategy(self, distribution): - keras_model = simple_sequential_model() - keras_model.compile( - loss='categorical_crossentropy', - metrics=[keras.metrics.CategoricalAccuracy()], - optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) - config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, - model_dir=self._base_dir, - train_distribute=distribution) - with self.cached_session(): - est_keras = keras_lib.model_to_estimator( - keras_model=keras_model, config=config) - before_eval_results = est_keras.evaluate( - input_fn=get_ds_test_input_fn, steps=1) - est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) - after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, - steps=1) - self.assertLess(after_eval_results['loss'], before_eval_results['loss']) - - writer_cache.FileWriterCache.clear() - gfile.DeleteRecursively(self._config.model_dir) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) - def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): - train_data, test_data = get_multi_inputs_multi_outputs_data() - - def train_input_fn(): - input_dict = { - 'input_a': train_data['input_a'], - 'input_b': train_data['input_b'], - 'input_m': train_data['input_m'].astype(np.str) - } - output_dict = { - 'dense_2': train_data['output_c'], - 'dense_3': train_data['output_d'] - } - return dataset_ops.Dataset.from_tensor_slices((input_dict, - output_dict)).batch(16) - - def eval_input_fn(): - input_dict = { - 'input_a': test_data['input_a'], - 'input_b': test_data['input_b'], - 'input_m': test_data['input_m'].astype(np.str) - } - output_dict = { - 'dense_2': test_data['output_c'], - 'dense_3': test_data['output_d'] - } - return dataset_ops.Dataset.from_tensor_slices((input_dict, - output_dict)).batch(16) - - self.do_test_multi_inputs_multi_outputs_with_input_fn( - distribution, train_input_fn, eval_input_fn) - - def do_test_multi_inputs_multi_outputs_with_input_fn( - self, distribution, train_input_fn, eval_input_fn): - config = run_config_lib.RunConfig( - tf_random_seed=_RANDOM_SEED, - model_dir=self._base_dir, - train_distribute=distribution) - with self.cached_session(): - model = multi_inputs_multi_outputs_model() - est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) - baseline_eval_results = est_keras.evaluate( - input_fn=eval_input_fn, steps=1) - est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) - eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) - self.assertLess(eval_results['loss'], baseline_eval_results['loss']) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) - def test_keras_optimizer_with_distribution_strategy(self, distribution): - keras_model = simple_sequential_model() - keras_model.compile( - loss='categorical_crossentropy', - optimizer=keras.optimizers.rmsprop(lr=0.01)) - - config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, - model_dir=self._base_dir, - train_distribute=distribution) - with self.cached_session(): - est_keras = keras_lib.model_to_estimator(keras_model=keras_model, - config=config) - with self.assertRaisesRegexp(ValueError, - 'Only TensorFlow native optimizers are ' - 'supported with DistributionStrategy.'): - est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) - - writer_cache.FileWriterCache.clear() - gfile.DeleteRecursively(self._config.model_dir) - - class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. - self.assertAllEqual(x, val) - - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_calculating_input_params_no_steps_no_batch_size(self, distribution): - # Calculate the per_replica_batch_size scaling factor for strategies - # that use per_core_batch_size - replica_scale_factor = 1.0 - if not distributed_training_utils.global_batch_size_supported(distribution): - replica_scale_factor = distribution.num_replicas_in_sync - - with self.cached_session(): - # Input samples of different sizes - input_20_samples = np.zeros((20, 3), dtype=np.float32) - input_63_samples = np.zeros((63, 3), dtype=np.float32) - input_64_samples = np.zeros((64, 3), dtype=np.float32) - - # Default global batch size 32 for input with 64 samples run in 2 steps - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=None, batch_size=None) - self.assertEqual(batch_size, 32 // replica_scale_factor) - self.assertEqual(steps, 2) - - # Computed global batch size 20 is lower than 32 if we pass less samples. - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_20_samples, steps=None, batch_size=None) - self.assertEqual(batch_size, 20 // replica_scale_factor) - self.assertEqual(steps, 1) - - # Default global batch size 32 cannot be used with 63 samples. - with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): - distributed_training_utils.get_input_params( - distribution, input_63_samples, steps=None, batch_size=None) - - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_calculating_input_params_with_steps_no_batch_size(self, - distribution): - # Calculate the per_replica_batch_size scaling factor for strategies - # that use per_core_batch_size - replica_scale_factor = 1.0 - if not distributed_training_utils.global_batch_size_supported(distribution): - replica_scale_factor = distribution.num_replicas_in_sync - - with self.cached_session(): - # Input samples of different sizes - input_63_samples = np.zeros((63, 3), dtype=np.float32) - input_64_samples = np.zeros((64, 3), dtype=np.float32) - - # Computed global batch size is correct for number of specified 1 step - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=1, batch_size=None) - self.assertEqual(batch_size, 64 // replica_scale_factor) - self.assertEqual(steps, 1) - - # Computed global batch size is correct for number of specified 2 steps - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=2, batch_size=None) - self.assertEqual(batch_size, 32 // replica_scale_factor) - self.assertEqual(steps, 2) - - # All samples can not be consumed in specified number of steps - with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): - distributed_training_utils.get_input_params( - distribution, input_63_samples, steps=2, batch_size=None) - - # This cases is different for different strategies due to the - # difference in supported batch size being global or per-replica. - if replica_scale_factor == 1: - # Computed global batch size is correct even if not sharadable - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_63_samples, steps=3, batch_size=None) - self.assertEqual(batch_size, 21) - self.assertEqual(steps, 3) - else: - # Computed global batch size can not be sharded across replicas - with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' - 'across the sync replicas'): - distributed_training_utils.get_input_params( - distribution, input_63_samples, steps=1, batch_size=None) - - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_calculating_input_params_no_steps_with_batch_size(self, - distribution): - # Calculate the per_replica_batch_size scaling factor for strategies - # that use per_core_batch_size - replica_scale_factor = 1.0 - if not distributed_training_utils.global_batch_size_supported(distribution): - replica_scale_factor = distribution.num_replicas_in_sync - - with self.cached_session(): - input_64_samples = np.zeros((64, 3), dtype=np.float32) - - # Computed steps is correct for specified batch size - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=None, batch_size=16) - self.assertEqual(batch_size, 16) - self.assertEqual(steps, 4 // replica_scale_factor) - - # Computed steps is correct for specified batch size - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=None, batch_size=32) - self.assertEqual(batch_size, 32) - self.assertEqual(steps, 2 // replica_scale_factor) - - # Number of samples is not divisible by the global batch size - with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): - distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=None, batch_size=20) - - # Number of samples is not divisible by the global batch size - with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): - distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=None, batch_size=3) - - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_calculating_input_params_with_steps_with_batch_size(self, - distribution): - with self.cached_session(): - input_64_samples = np.zeros((64, 3), dtype=np.float32) - - # No change to steps and batch size if both specified and feasible - steps, batch_size = distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=5, batch_size=3) - self.assertEqual(batch_size, 3) - self.assertEqual(steps, 5) - - # Number of samples is less than global batch size * steps - with self.assertRaisesRegexp(ValueError, 'less than samples required'): - distributed_training_utils.get_input_params( - distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): @@ -1039,7 +746,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap(model._distributed_model) + grouped_models = distribution.unwrap(model._distributed_model_train) with distribution.scope(): for m in grouped_models: self.assertAllClose(0.001, keras.backend.get_value( @@ -1048,54 +755,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase, class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_shape_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - with distribution.scope(): - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - with distribution.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) - @combinations.generate(combinations.combine( distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, @@ -1135,14 +794,14 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): # Test with not specifying the `steps` argument. with self.assertRaisesRegexp( - ValueError, 'you should specify the `steps_per_epoch` argument'): + ValueError, 'the `steps_per_epoch` argument'): model.fit(dataset, epochs=1, verbose=0) with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + 'the `steps` argument'): model.evaluate(dataset, verbose=0) with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + 'the `steps` argument'): model.predict(dataset, verbose=0) @combinations.generate(combinations.combine( diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test.py b/tensorflow/contrib/distribute/python/keras_correctness_test.py deleted file mode 100644 index e078731610882bfe6d5a97b1636d9a4a1325b047..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/keras_correctness_test.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Correctness tests for tf.keras using DistributionStrategy.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import tpu_strategy -from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.eager import test -from tensorflow.python.framework import random_seed -from tensorflow.python.keras.engine import distributed_training_utils -from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras -from tensorflow.python.training import gradient_descent - -_RANDOM_SEED = 1337 - -# Note: Please make sure the tests in this file are also covered in -# keras_backward_compat_test for features that are supported with both APIs. - - -def batch_wrapper(dataset, batch_size, distribution, repeat=None): - if repeat: - dataset = dataset.repeat(repeat) - # TPUs currently require fully defined input shapes, drop_remainder ensures - # the input will have fully defined shapes. - if isinstance(distribution, tpu_strategy.TPUStrategy): - return dataset.batch(batch_size, drop_remainder=True) - else: - return dataset.batch(batch_size) - - -def get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict): - """Generates the inputs for correctness check when enable Keras with DS.""" - training_epochs = 2 - global_batch_size = 64 - batch_size = global_batch_size - # TODO(b/118776054): Use global batch size for Keras/DS support. - use_per_core_batch_size = ( - with_distribution and - not distributed_training_utils.global_batch_size_supported( - with_distribution)) - if use_per_core_batch_size: - batch_size //= with_distribution.num_replicas_in_sync - - if use_numpy: - training_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - 'epochs': training_epochs, - 'shuffle': False, - } - - if use_validation_data: - eval_inputs = None - training_inputs['validation_data'] = (x_train, y_train) - else: - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } - predict_inputs = { - 'x': np.array(x_predict, dtype=np.float32), - } - else: - # For dataset inputs, we do not pass batch_size to - # keras.fit/evaluate/predict. The batch size is part of the dataset. - train_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper( - train_dataset, batch_size, with_distribution, repeat=training_epochs) - - training_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'epochs': training_epochs, - 'shuffle': False, - 'steps_per_epoch': len(x_train) // global_batch_size, - } - if use_validation_data: - eval_inputs = None # Remove the eval_inputs - eval_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper(eval_dataset, batch_size, with_distribution) - training_inputs['validation_data'] = x - training_inputs['validation_steps'] = 5 - else: - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } - - predict_batch_size = len(x_predict) - if use_per_core_batch_size: - predict_batch_size //= with_distribution.num_replicas_in_sync - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, with_distribution) - predict_inputs = { - 'steps': 1, - 'x': predict_dataset, - } - - return training_inputs, eval_inputs, predict_inputs - - -strategies_minus_tpu = [ - combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus] - -tpu_strategies = [ - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] - - -def strategy_minus_tpu_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager']) - - -def tpu_strategy_combinations(): - return combinations.combine( - distribution=tpu_strategies, - mode=['graph']) - - -def all_strategy_combinations(): - return strategy_minus_tpu_combinations() + tpu_strategy_combinations() - - -def strategy_and_input_combinations(): - return ( - combinations.times( - combinations.combine(distribution=strategies_minus_tpu), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]) - + combinations.combine(mode=['eager'], - use_numpy=[False], - use_validation_data=[False])) + - combinations.times( - combinations.combine(distribution=tpu_strategies), - combinations.combine(mode=['graph'], - use_numpy=[True, False], - use_validation_data=[True, False]))) - - -class TestDistributionStrategyCorrectness(test.TestCase, - parameterized.TestCase): - - @combinations.generate(all_strategy_combinations()) - def test_metric_correctness(self, distribution): - with self.cached_session(): - keras.backend.set_image_data_format('channels_last') - num_samples = 10000 - - x_train = np.random.randint(0, 2, num_samples) - x_train = np.reshape(x_train, (num_samples, 1)) - y_train = x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - - # Create identity model. - with distribution.scope(): - model = keras.Sequential() - model.add( - keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - metrics=[keras.metrics.BinaryAccuracy()]) - - batch_size = 64 - if not distributed_training_utils.global_batch_size_supported( - distribution): - batch_size //= distribution.num_replicas_in_sync - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - - history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) - - @combinations.generate(all_strategy_combinations()) - def test_eval_metrics_correctness(self, distribution): - with self.cached_session(): - with distribution.scope(): - model = keras.Sequential() - model.add( - keras.layers.Dense( - 3, activation='relu', input_dim=4, kernel_initializer='ones')) - model.add( - keras.layers.Dense( - 1, activation='sigmoid', kernel_initializer='ones')) - model.compile( - loss='mae', - metrics=['accuracy', keras.metrics.BinaryAccuracy()], - optimizer=gradient_descent.GradientDescentOptimizer(0.001)) - - # verify correctness of stateful and stateless metrics. - x = np.ones((100, 4)).astype('float32') - y = np.ones((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = batch_wrapper(dataset, 4, distribution) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 1.) - self.assertEqual(outs[2], 1.) - - y = np.zeros((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = batch_wrapper(dataset, 4, distribution) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 0.) - self.assertEqual(outs[2], 0.) - - @combinations.generate(strategy_and_input_combinations()) - def test_correctness(self, distribution, use_numpy, use_validation_data): - with self.cached_session(): - default_tolerance = 1e-5 - tol_table = {} - - if isinstance(distribution, ( - mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy, - distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access - # TODO(b/119257215): Weights are not exactly the same, so use larger - # tolerance for now. Predict should be related to weights. - tol_table = { - 'weights_1': 1e-4, - 'weights_2': 1e-4, - 'predict_result_1': 1e-4, - } - - keras.backend.set_image_data_format('channels_last') - np.random.seed(_RANDOM_SEED) - random_seed.set_random_seed(_RANDOM_SEED) - - # Train, eval, and predict datasets are created with the same input numpy - # arrays. - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] - - # The model is built once and the initial weights are saved. - # This is used to initialize the model for both the distribution and - # non-distribution run. In addition, we add few non-linear layers to make - # it non-trivial. - def _create_model(): - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - return model - - model = _create_model() - initial_weights = model.get_weights() - del model # avoid accident usage. - - def _build_and_compile_model(): - # We have initialized the model to the same weight for the distribution - # and non-distribution run. - model = _create_model() - model.set_weights(initial_weights) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent_keras.SGD(0.5), - metrics=['mse']) - return model - - def fit_eval_and_predict(with_distribution=None): - if with_distribution: - with with_distribution.scope(): - model = _build_and_compile_model() - else: - model = _build_and_compile_model() - - training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, use_validation_data, - with_distribution, - x_train, y_train, x_predict)) - - result = {} - result['training_history_1'] = model.fit(**training_inputs).history - - if eval_inputs is not None: - result['eval_result_1'] = model.evaluate(**eval_inputs) - - result['weights_1'] = model.get_weights() - result['predict_result_1'] = model.predict(**predict_inputs) - - # Train and eval again to mimic user's flow. - - result['training_history_2'] = model.fit(**training_inputs).history - - if eval_inputs is not None: - result['eval_result_2'] = model.evaluate(**eval_inputs) - - result['weights_2'] = model.get_weights() - - return result - - results_with_ds = fit_eval_and_predict(with_distribution=distribution) - results_without_ds = fit_eval_and_predict(with_distribution=None) - - # Verify that the weights, training history, eval results, predict outputs - # are the same within some limits of tolerance. - for key in results_with_ds: - if (key.startswith('training_history') and - isinstance(distribution, tpu_strategy.TPUStrategy) and - distribution.extended.steps_per_run > 1): - # TODO(b/119894254): Enable this test for all cases once the - # underlying bug is fixed. - continue - - tolerance = tol_table.get(key, default_tolerance) - - self.assertAllClose( - results_with_ds[key], - results_without_ds[key], - atol=tolerance, - rtol=tolerance, - msg='Fail to assert {}.'.format(key)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test_base.py b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0c783099b5267d6f57f755ca67dae05099e874d8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py @@ -0,0 +1,491 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras.engine import distributed_training_utils + +_RANDOM_SEED = 1337 +_EVAL_STEPS = 20 +_GLOBAL_BATCH_SIZE = 64 + +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. + + +all_strategies = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step, +] + + +def eager_mode_test_configuration(): + return combinations.combine(mode='eager', + use_numpy=False, + use_validation_data=False) + + +def graph_mode_test_configuration(): + return combinations.combine(mode='graph', + use_numpy=[True, False], + use_validation_data=[True, False]) + + +def all_strategy_and_input_config_combinations(): + return ( + combinations.times( + combinations.combine(distribution=all_strategies), + eager_mode_test_configuration() + graph_mode_test_configuration())) + + +def strategies_for_embedding_models(): + """Returns distribution strategies to test for embedding models. + + Since embedding models take longer to train, we disregard OneDeviceStrategy + and DefaultStrategy in order to prevent testing timeouts. + """ + + strategies = [s for s in all_strategies + if not s.required_tpu and s.required_gpus is not None] + strategies.append(combinations.tpu_strategy_loop_on_device) + strategies.append(combinations.tpu_strategy_one_step_loop_on_device) + return strategies + + +def test_combinations_for_embedding_model(): + return ( + combinations.times( + combinations.combine(distribution= + strategies_for_embedding_models()), + (graph_mode_test_configuration() + + eager_mode_test_configuration()))) + + +def test_combinations_with_tpu_strategies(): + tpu_strategies = [combinations.tpu_strategy_loop_on_device, + combinations.tpu_strategy_one_step_loop_on_device] + + return ( + combinations.times( + combinations.combine(distribution=tpu_strategies), + graph_mode_test_configuration())) + + +class MaybeDistributionScope(object): + """Provides a context allowing no distribution strategy.""" + + def __init__(self, distribution): + self._distribution = distribution + self._scope = None + + def __enter__(self): + if self._distribution: + self._scope = self._distribution.scope() + self._scope.__enter__() + + def __exit__(self, exc_type, value, traceback): + if self._distribution: + self._scope.__exit__(exc_type, value, traceback) + self._scope = None + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_batch_size(global_batch_size, distribution): + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. + use_per_core_batch_size = ( + distribution and + not distributed_training_utils.global_batch_size_supported(distribution)) + if use_per_core_batch_size: + batch_size //= distribution.num_replicas_in_sync + return batch_size + + +def get_data_size(data): + """Gets the size of data in list, tuple, dict, or a numpy array.""" + assert isinstance(data, (np.ndarray, list, dict, tuple)) + + if isinstance(data, np.ndarray): + return len(data) + + if isinstance(data, (list, tuple)): + return len(data[0]) + + return len(six.next(six.itervalues(data))) + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = _GLOBAL_BATCH_SIZE + batch_size = get_batch_size(global_batch_size, with_distribution) + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': x_predict + } + else: + training_data_size = get_data_size(x_train) + if training_data_size < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: + # Currently, we cannot detect the size of a dataset. So, the eval steps is + # hard coded. + raise ValueError('x_train must have at least ' + '_GLOBAL_BATCH_SIZE * _EVAL_STEPS samples') + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution, + repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': training_data_size // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': _EVAL_STEPS, + } + + predict_batch_size = get_batch_size(get_data_size(x_predict), + with_distribution) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, + with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +def fit_eval_and_predict(initial_weights, input_fn, model_fn, + distribution=None, is_stateful_model=False): + """Generates results for fit/predict/evaluate for given model.""" + model = model_fn(initial_weights=initial_weights, distribution=distribution) + training_inputs, eval_inputs, predict_inputs = input_fn(distribution) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + + if predict_inputs is not None: + # Check correctness of the result of predict() invoked + # multiple times -- as for stateful models, result of + # predict may differ for each batch. + predict_length = 1 + if is_stateful_model: + predict_length = 3 + for i in range(predict_length): + result_key = 'predict_result_{}'.format(i) + result[result_key] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + +def compare_results(results_with_ds, results_without_ds, distribution, + testcase): + """Compares results of model compiled with/without distribution strategy.""" + + default_tolerance = 1e-5 + relaxed_tolerance = 1e-4 + + def _get_compare_result_tolerance(key): + """Returns tolerance to compare results.""" + # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same, + # so use larger tolerance for now. Predict should be related to weights. + if (isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access + key.startswith(('weights_1', 'weights_2', 'predict_result'))): + return relaxed_tolerance + + return default_tolerance + + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. + continue + + tolerance = _get_compare_result_tolerance(key) + testcase.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +def should_skip_tpu_with_eager(distribution): + return (context.executing_eagerly() and + isinstance(distribution, tpu_strategy.TPUStrategy)) + + +class LearningRateBatchScheduler(keras.callbacks.Callback): + """Scheduler that dynamically sets the learning rate of model.""" + + def __init__(self, update_freq=None): + self._update_freq = update_freq + + def on_batch_begin(self, batch, logs=None): + if self._update_freq and batch % self._update_freq != 0: + return + + # To avoid divergence, limit the value range. + lr = 0.001 * (batch % 10) + keras.backend.set_value(self.model.optimizer.lr, lr) + + +class TestDistributionStrategyCorrectnessBase(test.TestCase, + parameterized.TestCase): + """Model agnostic testing infra to test correctness of Keras models.""" + + def set_up_test_config(self, use_numpy=False, + use_validation_data=False, + with_batch_norm=False): + self.use_numpy = use_numpy + self.use_validation_data = use_validation_data + self.with_batch_norm = with_batch_norm + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + def get_data(self): + num_samples = 10000 + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + return (x_train.astype('float32'), y_train.astype('float32'), None) + + def get_model(self, distribution=None): + raise NotImplementedError + + def skip_unsupported_test_configuration(self, distribution): + if should_skip_tpu_with_eager(distribution): + self.skipTest('TPUStrategy does not support eager mode now.') + + if context.executing_eagerly() and self.use_numpy: + self.skipTest('Numpy as inputs is not supported with strategy in eager.') + + if context.executing_eagerly() and self.use_validation_data: + self.skipTest('TODO(hongjunchoi): Add test logic for using validation ' + 'data for eager execution.') + return + + def run_correctness_test(self, + distribution, + use_numpy, + use_validation_data, + with_batch_norm=False, + is_stateful_model=False): + with self.cached_session(): + self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm) + self.skip_unsupported_test_configuration(distribution) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + x_train, y_train, x_predict = self.get_data() + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. + model = self.get_model() + initial_weights = model.get_weights() + + def input_fn(dist): + return get_correctness_test_inputs( + use_numpy, use_validation_data, dist, x_train, y_train, x_predict) + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution, is_stateful_model=is_stateful_model) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None, is_stateful_model=is_stateful_model) + + # First, special case, for multi-replica distributed training, batch norm + # is not aggregated globally. So it is expected to have different weights. + if (self.with_batch_norm and + distribution.num_replicas_in_sync > 1): + with self.assertRaises(AssertionError): + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + else: + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + def run_dynamic_lr_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = self.get_data() + model = self.get_model() + initial_weights = model.get_weights() + update_freq = None + + if (isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # For TPUStrategy with steps_per_run > 1, the callback is not invoked + # every step. So, to compare the CPU/TPU, we let the CPU to behave the + # same as TPU. + update_freq = distribution.extended.steps_per_run + + def input_fn(dist): + """Generates training test given test configuration.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = get_batch_size(global_batch_size, dist) + + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + 'callbacks': [LearningRateBatchScheduler(update_freq)], + 'validation_data': (x_train, y_train) + } + # In this test case, we do not care eval and predict. + eval_inputs, predict_inputs = None, None + return training_inputs, eval_inputs, predict_inputs + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None) + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + +class TestDistributionStrategyEmbeddingModelCorrectnessBase( + TestDistributionStrategyCorrectnessBase): + """Base class to test correctness of Keras models with embedding layers.""" + + def get_data(self, + count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + distribution = [] + for _ in range(num_classes): + dist = np.abs(np.random.randn(max_word_id)) + dist /= np.sum(dist) + distribution.append(dist) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + num_words = np.random.randint(min_words, max_words, size=1)[0] + word_ids = np.random.choice( + max_word_id, size=num_words, replace=True, p=distribution[label]) + word_ids = word_ids + labels.append(label) + features.append(word_ids) + + features = keras.preprocessing.sequence.pad_sequences( + features, maxlen=max_words) + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1)) + x_predict = x_train[:_GLOBAL_BATCH_SIZE] + return x_train, y_train, x_predict + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dae32188917cce9209b8e51032ef808352bc257c --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py @@ -0,0 +1,171 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras DNN model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + + +def all_strategy_combinations_with_eager_and_graph_modes(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, + mode=['graph', 'eager']) + + +def all_strategy_combinations_with_graph_mode(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, mode=['graph']) + + +class TestDistributionStrategyDnnCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + # We add few non-linear layers to make it non-trivial. + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse']) + return model + + def get_data(self): + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(all_strategy_combinations_with_graph_mode()) + def test_dnn_with_dynamic_learning_rate(self, distribution): + self.run_dynamic_lr_test(distribution) + + +class TestDistributionStrategyDnnMetricCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add(keras.layers.Dense(1, + input_shape=(1,), + kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()]) + return model + + def run_metric_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = self.get_data() + model = self.get_model(distribution=distribution) + + batch_size = 64 + batch_size = (keras_correctness_test_base. + get_batch_size(batch_size, distribution)) + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = (keras_correctness_test_base. + batch_wrapper(train_dataset, batch_size, distribution)) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_simple_dnn_metric_correctness(self, distribution): + self.run_metric_correctness_test(distribution) + + +class TestDistributionStrategyDnnMetricEvalCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001)) + return model + + def run_eval_metrics_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + model = self.get_model(distribution=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_identity_model_metric_eval_correctness(self, distribution): + self.run_eval_metrics_correctness_test(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e881bb70ecc428e3f972cde5f19c1b61b1dc0f0b --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py @@ -0,0 +1,150 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness test for tf.keras Embedding models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + if self.use_distributed_dense: + word_embed = keras.layers.TimeDistributed(keras.layers.Dense(4))( + word_embed) + avg = keras.layers.GlobalAveragePooling1D()(word_embed) + preds = keras.layers.Dense(2, activation='softmax')(avg) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + + self.use_distributed_dense = False + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_embedding_time_distributed_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.use_distributed_dense = True + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +class DistributionStrategySiameseEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids_a = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_a') + word_ids_b = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_b') + + def submodel(embedding, word_ids): + word_embed = embedding(word_ids) + rep = keras.layers.GlobalAveragePooling1D()(word_embed) + return keras.Model(inputs=[word_ids], outputs=[rep]) + + word_embed = keras.layers.Embedding( + input_dim=20, + output_dim=10, + input_length=max_words, + embeddings_initializer=keras.initializers.RandomUniform(0, 1)) + + a_rep = submodel(word_embed, word_ids_a).outputs[0] + b_rep = submodel(word_embed, word_ids_b).outputs[0] + sim = keras.layers.Dot(axes=1, normalize=True)([a_rep, b_rep]) + + model = keras.Model(inputs=[word_ids_a, word_ids_b], outputs=[sim]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='mse', + metrics=['mse']) + return model + + def get_data(self, + count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE * + keras_correctness_test_base._EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + features_a, labels_a, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + features_b, labels_b, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + y_train = np.zeros((count, 1), dtype=np.float32) + y_train[labels_a == labels_b] = 1.0 + y_train[labels_a != labels_b] = -1.0 + # TODO(b/123360757): Add tests for using list as inputs for multi-input + # models. + x_train = { + 'words_a': features_a, + 'words_b': features_b, + } + x_predict = x_train + + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_siamese_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f625664372dfb6814ccbe9539f6abe018d2a4447 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras CNN models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyCnnCorrectnessTest( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + image = keras.layers.Input(shape=(28, 28, 3), name='image') + c1 = keras.layers.Conv2D( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4))( + image) + if self.with_batch_norm: + c1 = keras.layers.BatchNormalization(name='bn1')(c1) + c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) + logits = keras.layers.Dense( + 10, activation='softmax', name='pred')( + keras.layers.Flatten()(c1)) + model = keras.Model(inputs=[image], outputs=[logits]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + return model + + def get_data(self, + count=keras_correctness_test_base._GLOBAL_BATCH_SIZE + * keras_correctness_test_base._EVAL_STEPS, + shape=(28, 28, 3), + num_classes=10): + centers = np.random.randn(num_classes, *shape) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape)) + offset = offset.reshape(shape) + labels.append(label) + features.append(centers[label] + offset) + + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.float32).reshape((count, 1)) + x_predict = x_train + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + with_batch_norm=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed2dfa206cdf4be24a88b1d54090487c1873399 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras LSTM model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyLstmModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index cce93b3c10a2ac7bd1c594a5027b9d51629bb915..5349794334b7f6ea3d718343fa84c693dd3d7a3c 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -65,7 +65,8 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): devices = ['/device:GPU:0', '/device:CPU:0'] with distribution.scope(): - (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) + (var, m, v, op, + counter) = distribution.extended.call_for_each_replica(create_fn) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( diff --git a/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab56c01d862354bd74330f769502692bd8a8b982 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py @@ -0,0 +1,99 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateful tf.keras LSTM models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +def strategies_for_stateful_embedding_model(): + """Returns TPUStrategy with single core device assignment.""" + + return [combinations.tpu_strategy_loop_on_device_one_core, + combinations.tpu_strategy_one_step_loop_on_device_one_core] + + +def test_combinations_for_stateful_embedding_model(): + return ( + combinations.combine( + distribution=strategies_for_stateful_embedding_model(), + mode='graph', + use_numpy=False, + use_validation_data=False + )) + + +class DistributionStrategyStatefulLstmModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE + + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), + batch_size=batch_size, + dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False, + stateful=True)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(test_combinations_for_stateful_embedding_model()) + def test_stateful_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + @combinations.generate(keras_correctness_test_base. + test_combinations_with_tpu_strategies()) + def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( + self, distribution, use_numpy, use_validation_data): + with self.assertRaisesRegexp(ValueError, + 'Single core must be used for computation ' + 'on stateful models. Consider adding ' + '`device_assignment` parameter to ' + 'TPUStrategy'): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 84e9aea228352e0a6010fe95529407818d020b5f..17ed87145984af96073c78cf4974527e558d3842 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import os +import tempfile from absl.testing import parameterized import numpy as np @@ -245,15 +246,32 @@ def all_strategy_combinations(): return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -# TODO(priyag): Add v2 optimizers here. +def all_strategy_combinations_minus_default(): + strategy_minus_default_combinations = combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager']) + return strategy_minus_default_combinations + tpu_strategy_combinations() + + def strategy_and_optimizer_combinations(): + # TODO(b/122372746): Uncomment optimizers after they pass tests. return combinations.times( all_strategy_combinations(), - combinations.combine( - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn])) + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + # combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + # combinations.rmsprop_optimizer_keras_v2_fn + ])) def strategy_for_numpy_input_combinations(): @@ -417,15 +435,6 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. - self.assertAllEqual(x, val) - @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -564,26 +573,26 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, metrics = ['mae'] model.compile(optimizer, loss, metrics=metrics) - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) - # Call fit with validation data - model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, - validation_data=(inputs, targets)) + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) - # TODO(anjalisridhar): We need tests for when the batch size and steps are - # smaller and results in a 0 batch_size and steps value. - model.evaluate(inputs, targets) - # with steps - model.evaluate(inputs, targets, steps=2) - # with batch_size - model.evaluate(inputs, targets, batch_size=8) + # TODO(anjalisridhar): We need tests for when the batch size and steps + # are smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) - model.predict(inputs) - # with steps - model.predict(inputs, steps=2) - # with batch_size - model.predict(inputs, batch_size=8) + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): @@ -937,9 +946,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def testOptimizerWithCallbacks(self, distribution): with self.cached_session(): - # TODO(b/120946189): Investigate why default strategy + eager fails. - if '_Default' in distribution.__class__.__name__: - self.skipTest('Disable the test for default strategy.') with distribution.scope(): model = get_model() optimizer = gradient_descent_keras.SGD(0.01) @@ -1045,14 +1051,12 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): # Test with not specifying the `steps` argument. with self.assertRaisesRegexp( - ValueError, 'you should specify the `steps_per_epoch` argument'): + ValueError, 'the `steps_per_epoch` argument'): model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'the `steps` argument'): model.evaluate(dataset, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'the `steps` argument'): model.predict(dataset, verbose=0) @combinations.generate(combinations.combine( @@ -1119,12 +1123,15 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(all_strategy_combinations()) - def test_batchnorm_correctness(self, distribution): + @combinations.generate(combinations.times( + all_strategy_combinations(), + combinations.combine(fused=[True, False]))) + def test_batchnorm_correctness(self, distribution, fused): with self.cached_session(): with distribution.scope(): model = keras.models.Sequential() - norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + norm = keras.layers.BatchNormalization( + input_shape=(10,), momentum=0.8, fused=fused) model.add(norm) model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01)) @@ -1148,5 +1155,78 @@ class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) +class TestDistributionStrategySaveLoadWeights(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_save_load_h5(self, distribution): + with self.cached_session(): + dataset = get_dataset(distribution) + with distribution.scope(): + model = get_model() + model.compile(gradient_descent_keras.SGD(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp('.h5') + model.save_weights(weights_file) + + model_2 = get_model() + model_2.compile(gradient_descent_keras.SGD(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict(get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_save_load_checkpointable(self, distribution): + # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 + with self.cached_session(): + dataset = get_dataset(distribution) + with distribution.scope(): + model = get_model() + model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp() + model.save_weights(weights_file) + + model_2 = get_model() + model_2.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict(get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + +class TestDistributionStrategyValidation(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 32a0d199434e0627122fd4e47cf8894079ef3a1e..a663e809dd45ea099e1d8a08e681d07b05bee3c9 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -95,16 +95,15 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): - iterator = distribution.distribute_dataset( - dataset_fn).make_initializable_iterator() + iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): - value, update = distribution.call_for_each_replica( + value, update = distribution.extended.call_for_each_replica( metric_fn, args=(inputs,)) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] @@ -114,15 +113,14 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): distribution.num_replicas_in_sync * distribution.extended.steps_per_run) else: - value, update = distribution.call_for_each_replica( + value, update = distribution.extended.call_for_each_replica( metric_fn, args=(iterator.get_next(),)) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_replicas_in_sync" with "1". batches_per_update = distribution.num_replicas_in_sync - self.evaluate(iterator.initializer) - self.evaluate(distribution.initialize()) + self.evaluate(iterator.initialize()) self.evaluate(variables.local_variables_initializer()) batches_consumed = 0 @@ -136,8 +134,6 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if batches_consumed >= 4: # Consume 4 input batches in total. break - self.evaluate(distribution.finalize()) - @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 824c4b09371fcc8d590f2d2b2be8f39b4a585b27..f06c9b75644b2890b7657f75e74e4e20a6f15705 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,12 +41,9 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): - def _get_iterator(self, ds): - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate(iterator.initializer) + def _get_iterator(self, strategy, input_fn): + iterator = strategy.make_input_fn_iterator(lambda _: input_fn()) + self.evaluate(iterator.initialize()) return iterator @combinations.generate( @@ -67,15 +64,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=(inputs,))) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -84,12 +81,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -105,11 +99,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): return distribution.group( - distribution.call_for_each_replica( + distribution.extended.call_for_each_replica( model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): @@ -152,7 +146,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # `distribution.scope`. with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset_fn, layer = minimize_loss_example( + model_fn, dataset_fn, _ = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, @@ -161,24 +155,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=(inputs,))) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) - run_step() - self.evaluate(distribution.finalize()) - def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], @@ -197,7 +188,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual( get_expected_variables(optimizer_fn, - len(distribution.parameter_devices)), + len(distribution.extended.parameter_devices)), set(created_variables)) @combinations.generate( @@ -230,18 +221,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica(model_fn, args=(inputs,))) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) if update_ops_in_cross_replica_mode: fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -267,8 +258,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.combine( @@ -327,15 +316,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=(inputs,))) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -370,8 +359,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -412,7 +399,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return (train_op, loss) def step_fn(output_context, inputs): - (train_op, loss) = distribution.call_for_each_replica( + (train_op, loss) = distribution.extended.call_for_each_replica( model_fn, args=(output_context, inputs)) output_context.set_last_step_output( name="cross_replica_loss_reduced", @@ -423,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return distribution.group(train_op) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): initial_loss = lambda: constant_op.constant(1e7) @@ -439,7 +426,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): "cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) @@ -458,7 +445,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): reduced=False, distribution=distribution) return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -471,8 +457,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) self.assertTrue(loss_is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 71e50b83b079bc73a7b178356f0f26adbd98638f..5391e083fc9b3ed99cc64bbed11bdeb8dea07f93 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,11 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values # pylint: disable=protected-access,invalid-name @@ -48,8 +46,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): distributed environment. There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the + `client`, `job`, `task`, `cluster`, `in-graph replication` and + `synchronous training` and they have already been defined in the [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: @@ -104,6 +102,61 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): auto_shard_dataset) super(MirroredStrategy, self).__init__(extended) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. + + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `MirroredStrategy`. + + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. + + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. + + Args: + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).make_dataset_iterator(dataset) + + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `MirroredStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" @@ -135,19 +188,10 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - return values.DatasetIterator(dataset, self._input_workers) - - def _distribute_dataset(self, dataset_fn): - if self._local_mode: - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0) - else: - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._input_workers, - auto_shard=self._auto_shard_dataset) + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """The contrib version of Mirrored strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index f4becf1d6291cc0c7e2bdbc3911394764412b037..d6337d106fced921b8bda0a2faac99c2a77fab8e 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -66,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"])) -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class MirroredTwoDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLoss(self, distribution): if context.executing_eagerly(): @@ -114,9 +116,30 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + def one_device_combinations(): return combinations.combine( @@ -128,25 +151,42 @@ def one_device_combinations(): mode=["graph", "eager"]) +@combinations.generate(one_device_combinations()) class MirroredOneDeviceDistributionTest( strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase, parameterized.TestCase): - @combinations.generate(one_device_combinations()) def testMinimizeLoss(self, distribution): if context.executing_eagerly(): self._test_minimize_loss_eager(distribution) else: self._test_minimize_loss_graph(distribution) - @combinations.generate(one_device_combinations()) def testReplicaId(self, distribution): self._test_replica_id(distribution) - @combinations.generate(one_device_combinations()) def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + class MirroredStrategyVariableCreatorStackTest( test.TestCase, parameterized.TestCase): @@ -221,11 +261,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # TODO(priyag): Modify more tests to use this helper and check more # properties. - def _test_mv_properties(self, var, name): + def _test_mv_properties(self, var, name, strategy): self.assertIsInstance(var, values.MirroredVariable) self.assertEqual(name, var.name) + self.assertIs(strategy, var.distribute_strategy) for d in var.devices: self.assertEqual(d, var.get(d).device) + self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access def testVariableInFuncGraph(self, distribution): def model_fn(): @@ -237,8 +279,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v1 = variable_scope.variable(1.0, name="foo") v2 = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(v1, "foo:0") - self._test_mv_properties(v2, "bar:0") + self._test_mv_properties(v1, "foo:0", distribution) + self._test_mv_properties(v2, "bar:0", distribution) def testSingleVariable(self, distribution): def model_fn(): @@ -251,7 +293,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "foo:0") + self._test_mv_properties(result, "foo:0", distribution) def testUnnamedVariable(self, distribution): def model_fn(): @@ -261,7 +303,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self._test_mv_properties(result, "Variable:0") + self._test_mv_properties(result, "Variable:0", distribution) def testMultipleVariables(self, distribution): def model_fn(): @@ -274,7 +316,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self._test_mv_properties(v, "foo" + str(i) + ":0") + self._test_mv_properties(v, "foo" + str(i) + ":0", distribution) def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): @@ -324,14 +366,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - ds = distribution.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate([iterator.initializer]) - + iterator = distribution.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + self.evaluate(iterator.initialize()) features = iterator.get_next() with distribution.scope(): @@ -693,6 +730,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): distribution.extended.worker_devices[0]).read_value())) self.assertEqual(10.0, self.evaluate(ret_v_sum)) + def testVarDistributeStrategy(self, distribution): + with distribution.scope(): + mirrored = variable_scope.variable(1.0) + replica_local = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertIs(distribution, mirrored.distribute_strategy) + self.assertIs(distribution, replica_local.distribute_strategy) + @combinations.generate(combinations.combine( distribution=[ @@ -1215,7 +1261,7 @@ class MirroredStrategyDefunTest(test.TestCase): self.evaluate(device_result)) for defun in defuns: - # PolymorphicFunctions are specialized to the current device stack, so + # `Function`s are specialized to the current device stack, so # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 17b7ab74f63f42e1ee14a82d3bffdd1df9b25857..53e35ea6b75088a3de9866973f872da4a4ce25d6 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,7 +51,7 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") - session.run(step_callable._iterator.initializer) # pylint: disable=protected-access + session.run(step_callable.initialize()) self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 8f13e9153ea7a951dd722c4549882c97e79b57fe..c4622cdd2af2f6a9c936fe554bcc2eb76f805fdc 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -53,7 +53,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign with distribution.scope(), self.cached_session() as sess: - var, assign = distribution.call_for_each_replica(replica_fn) + var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) sess.run(distribution.unwrap(assign)) @@ -79,7 +79,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign.op with distribution.scope(), self.cached_session() as sess: - var, assign_op = distribution.call_for_each_replica(replica_fn) + var, assign_op = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([0.0, 0.0], var.eval()) sess.run(distribution.unwrap(assign_op)) @@ -152,7 +152,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign with distribution.scope(), self.cached_session() as sess: - var, assign = distribution.call_for_each_replica(replica_fn) + var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) sess.run(distribution.unwrap(assign)) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 5986bc4661f2615a16fcd8d5bf503f1f0dd3d504..24d6a443fe15c9b9ff34b7e6d3a5bc5a2bb7abfb 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six - from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -51,41 +51,38 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] + self._input_device = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(self._input_device, [self._device])] device_map = values.SingleDeviceMap(device) - self._input_workers = values.InputWorkers(device_map, worker_device_pairs) + self._input_workers = input_lib.InputWorkers( + device_map, worker_device_pairs) def _create_variable(self, next_creator, *args, **kwargs): colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: with ops.device(self._device): return next_creator(*args, **kwargs) - if isinstance(colocate_with, six.string_types): - with ops.device(colocate_with): - return next_creator(*args, **kwargs) - if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and - isinstance(colocate_with[0], six.string_types)): - with ops.device(colocate_with[0]): - return next_creator(*args, **kwargs) with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) + def _validate_colocate_with_variable(self, colocate_with_variable): + values.validate_colocate(colocate_with_variable, self) + def _make_dataset_iterator(self, dataset): """Make iterator from dataset without splitting the batch.""" - return values.DatasetIterator(dataset, self._input_workers) - - def _distribute_dataset(self, dataset_fn): - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0) + return input_lib.DatasetIterator(dataset, self._input_workers) def _make_input_fn_iterator( self, input_fn, replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - return values.InputFunctionIterator( + return input_lib.InputFunctionIterator( input_fn, self._input_workers, [distribute_lib.InputContext()]) + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, self._input_device, session) + def _broadcast_to(self, tensor, destinations): del destinations return tensor @@ -97,7 +94,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() def body(i, *args): """A wrapper around `fn` to create the while loop body.""" del args @@ -198,6 +195,7 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """Global and per-replica batching are equivalent for OneDeviceStrategy.""" return True diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index d46cd6f529e363f76bfa2b22339add63530cfde8..f81466a6c75f1cf287cdb00917872f77383c615e 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -25,7 +25,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util -class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): +class OneDeviceStrategyTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase): def _get_distribution_strategy(self): return one_device_strategy.OneDeviceStrategy("/device:CPU:0") @@ -57,6 +59,28 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): self._test_input_fn_iterator( iterator, d.extended.worker_devices, expected_values) + @test_util.run_in_graph_and_eager_modes + def testNumpyIterator(self): + self._test_numpy_iterator(self._get_distribution_strategy()) + + def testAllReduceSum(self): + self._test_all_reduce_sum(self._get_distribution_strategy()) + + def testAllReduceSumGradients(self): + self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + + def testAllReduceSumGradientTape(self): + self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + + def testAllReduceMean(self): + self._test_all_reduce_mean(self._get_distribution_strategy()) + + def testAllReduceMeanGradients(self): + self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + + def testAllReduceMeanGradientTape(self): + self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index fa4705af7cb592119f56686d1f693a156f7b4b13..e388061b17a9b92dedbbf9839049b13c8575a22c 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -41,21 +41,17 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - - ds = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() + iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) def run_step(): - return control_flow_ops.group(distribution.unwrap( - distribution.call_for_each_replica( - model_fn, args=(iterator.get_next(),)))) + return control_flow_ops.group( + distribution.unwrap( + distribution.extended.call_for_each_replica( + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: - sess.run(iterator.initializer) + sess.run(iterator.initialize()) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 5029d59641a25364d02874bd945af15147debc24..e42bc50fdc4e5e93c998708b0790fdea7768faf2 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,34 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import device_setter -from tensorflow.python.util import nest +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import parameter_server_strategy +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver + +# pylint: disable=protected-access,invalid-name,line-too-long +CoreParameterServerStrategy = parameter_server_strategy.ParameterServerStrategy +CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyExtended -_LOCAL_CPU = "/device:CPU:0" -_LOCAL_GPU_0 = "/device:GPU:0" +# pylint: enable=protected-access,invalid-name,line-too-long -# TODO(yuefengz): maybe cache variables on local CPU. -# TODO(yuefengz): we may want to set session options to disallow communication -# between workers. class ParameterServerStrategy(distribute_lib.DistributionStrategy): """A parameter server DistributionStrategy. + *** contrib version *** + This strategy class works for both local training and between-graph replicated training for multiple workers. If `cluster_spec` is specified, either passed in to __init__() method or parsed from the @@ -80,9 +70,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): variables. 3) It is also not recommended to open a colocation scope (i.e. calling - `tf.colocate_with`) under the strategy's scope. For colocating variables, - use `distribution.colocate_vars_with` instead. Colocation of ops will possibly - create conflicts of device assignment. + `tf.colocate_with`) under the strategy's scope. For colocating variables, use + `strategy.extended.colocate_vars_with` instead. Colocation of ops will + possibly create conflicts of device assignment. """ def __init__(self, num_gpus_per_worker=0): @@ -99,433 +89,84 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): super(ParameterServerStrategy, self).__init__( ParameterServerExtended(self, num_gpus_per_worker)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. -class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of ParameterServerStrategy.""" + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `ParameterServerStrategy`. - def __init__(self, container_strategy, num_gpus_per_worker): - super(ParameterServerExtended, self).__init__(container_strategy) - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local(num_gpus_per_worker) + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. - # We typically don't need to do all-reduce in this strategy. - self._cross_device_ops = ( - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( - reduce_to_device=_LOCAL_CPU)) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initialize devices for multiple workers. - - It creates variable devices and compute devices. Variables and operations - will be assigned to them respectively. We have one compute device per - replica. The variable device is a device function or device string. The - default variable device assigns variables to parameter servers in a - round-robin fashion. + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. Args: - num_gpus_per_worker: number of local GPUs or GPUs per worker. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. - Raises: - ValueError: if the cluster_spec doesn't have ps jobs. + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - assert cluster_spec - if not task_type or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - - worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) - - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - compute_devices = tuple( - "%s/device:GPU:%d" % (worker_device, i) - for i in range(num_gpus_per_worker) - ) - else: - compute_devices = (worker_device,) - - self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( - self._device_map, [(worker_device, compute_devices)]) - - # In distributed mode, place variables on ps jobs in a round-robin fashion. - # Note that devices returned from `replica_device_setter` are not - # canonical and therefore we don't canonicalize all variable devices to - # make them consistent. - # TODO(yuefengz): support passing a strategy object to control variable - # assignment. - # TODO(yuefengz): merge the logic of replica_device_setter into this - # class. - num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) - if num_ps_replicas == 0: - raise ValueError("The cluster spec needs to have `ps` jobs.") - self._variable_device = device_setter.replica_device_setter( - ps_tasks=num_ps_replicas, - worker_device=worker_device, - merge_devices=True, - cluster=cluster_spec) - - # The `_parameter_devices` is needed for the `parameter_devices` property - # and is a list of all variable devices. Here parameter devices are all - # tasks of the "ps" job. - self._parameter_devices = tuple(map("/job:ps/task:{}".format, - range(num_ps_replicas))) - - # Add a default device so that ops without specified devices will not end up - # on other workers. - self._default_device = worker_device - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - self._cluster_spec = cluster_spec - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker ParameterServerStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_ps_replicas = %r, is_chief = %r, device_map = %r, " - "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, - num_ps_replicas, self._is_chief, self._device_map, - self._variable_device) - - def _initialize_local(self, num_gpus_per_worker): - """Initialize internal devices for local training.""" - worker_device = device_util.canonicalize("/device:CPU:0") - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - compute_devices = tuple( - map("/device:GPU:{}".format, range(num_gpus_per_worker))) - else: - compute_devices = (_LOCAL_CPU,) - - self._device_map = values.ReplicaDeviceMap(compute_devices) - self._input_workers = values.InputWorkers( - self._device_map, [(worker_device, compute_devices)]) - - # If there is only one GPU, put everything on that GPU. Otherwise, place - # variables on CPU. - if num_gpus_per_worker == 1: - assert len(compute_devices) == 1 - self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = (_LOCAL_GPU_0,) - else: - self._variable_device = _LOCAL_CPU - self._parameter_devices = (_LOCAL_CPU,) - - self._is_chief = True - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info( - "ParameterServerStrategy with compute_devices = %r, " - "variable_device = %r", compute_devices, self._variable_device) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._input_workers, 0, - prefetch_on_device=True) - - def _make_dataset_iterator(self, dataset): - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - num_input_pipelines = multi_worker_util.worker_count( - self._cluster_spec, self._task_type) - else: - input_pipeline_id = 0 - num_input_pipelines = 1 - input_context = distribute_lib.InputContext( - num_input_pipelines=num_input_pipelines, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - return values.InputFunctionIterator( - input_fn, self._input_workers, [input_context]) - - def _broadcast_to(self, tensor, destinations): - # This is both a fast path for Python constants, and a way to delay - # converting Python values to a tensor until we know what type it - # should be converted to. Otherwise we have trouble with: - # global_step.assign_add(1) - # since the `1` gets broadcast as an int32 but global_step is int64. - if isinstance(tensor, (float, int)): - return tensor - if not cross_device_ops_lib.check_destinations(destinations): - # TODO(josh11b): Use current logical device instead of 0 here. - destinations = values.LogicalDeviceSpec( - device_map=self._device_map, logical_device=0) - return self._cross_device_ops.broadcast(tensor, destinations) - - def _allow_variable_partition(self): - return not context.executing_eagerly() - - # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through - # this creator, such as "MutableHashTable". - def _create_variable(self, next_creator, *args, **kwargs): - if self._num_replicas_in_sync > 1: - aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) - if aggregation not in ( - vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - def var_creator(*args, **kwargs): - """Create an AggregatingVariable and fix up collections.""" - # Record what collections this variable should be added to. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Create and wrap the variable. - v = next_creator(*args, **kwargs) - wrapped = values.AggregatingVariable(v, aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the contained - # variable to the TRAINABLE_VARIABLES collection, so we manually - # remove it and replace with the wrapper. We can't set "trainable" - # to False for next_creator() since that causes functions like - # implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - l.remove(v) - g.add_to_collections(collections, wrapped) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) - - return wrapped - else: - var_creator = next_creator - - if "colocate_with" in kwargs: - with ops.device(None): - with ops.colocate_with(kwargs["colocate_with"]): - return var_creator(*args, **kwargs) - - with ops.colocate_with(None, ignore_existing=True): - with ops.device(self._variable_device): - return var_creator(*args, **kwargs) - - def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), self._device_map, fn, args, kwargs) - - def _verify_destinations_not_different_worker(self, destinations): - if not self._cluster_spec: - return - if destinations is None: - return - for d in cross_device_ops_lib.get_devices_from(destinations): - d_spec = tf_device.DeviceSpec.from_string(d) - if d_spec.job == self._task_type and d_spec.task != self._task_id: - raise ValueError( - "Cannot reduce to another worker: %r, current worker is %r" % - (d, self._input_workers.worker_devices[0])) + return super(ParameterServerStrategy, self).make_dataset_iterator(dataset) - def _reduce_to(self, reduce_op, value, destinations): - self._verify_destinations_not_different_worker(destinations) - if not isinstance(value, values.DistributedValues): - # pylint: disable=protected-access - return cross_device_ops_lib.reduce_non_distributed_value( - reduce_op, self._device_map, value, destinations) - return self._cross_device_ops.reduce( - reduce_op, value, destinations=destinations) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. - def _batch_reduce_to(self, reduce_op, value_destination_pairs): - for _, destinations in value_destination_pairs: - self._verify_destinations_not_different_worker(destinations) - return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs) - - def _select_single_value(self, structured): - """Select any single values in `structured`.""" - - def _select_fn(x): # pylint: disable=g-missing-docstring - if isinstance(x, values.Mirrored): - if len(x.devices) == 1: - return x.primary - else: - raise ValueError( - "You cannot update variable with a Mirrored object with multiple " - "components %r when using ParameterServerStrategy. You must " - "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerReplica): - raise ValueError( - "You cannot update variable with a PerReplica object %r when using " - "ParameterServerStrategy. You must specify a single value or a " - "Mirrored with a single value" % x) - else: - return x - - return nest.map_structure(_select_fn, structured) - - def _update(self, var, fn, args, kwargs, group): - if isinstance(var, values.AggregatingVariable): - var = var.get() - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError( - "You can not update `var` %r. It must be a Variable." % var) - with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - result = fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - with ops.device( - colocate_with.device), distribute_lib.UpdateContext(colocate_with): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - return val.values - return (val,) - - def value_container(self, val): - if (hasattr(val, "_aggregating_container") and - not isinstance(val, values.AggregatingVariable)): - wrapper = val._aggregating_container() # pylint: disable=protected-access - if wrapper is not None: - return wrapper - return val - - def read_var(self, var): - # No need to distinguish between normal variables and replica-local - # variables. - return array_ops.identity(var) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the strategy class. - - The strategy object will be re-initialized if `cluster_spec` is given but - was not passed in the constructor. + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `ParameterServerStrategy`. Args: - session_config: not used currently. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. - - Raises: - ValueError: if `cluster_spec` is given but `task_type` or `task_id` is - not. + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. - if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, must also specify " - "`task_type` and `task_id`.") - self._cluster_spec = multi_worker_util.normalize_cluster_spec( - cluster_spec) - self._task_type = task_type - self._task_id = task_id - self._initialize_multi_worker(self._num_gpus_per_worker, - self._cluster_spec, task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - if not self._cluster_spec: - updated_config.isolate_session_state = True - return updated_config - - updated_config.isolate_session_state = False - - assert self._task_type - assert self._task_id is not None + return super(ParameterServerStrategy, + self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) - # The device filters prevent communication between workers. - if self._task_type not in ["chief", "worker"]: - return updated_config - del updated_config.device_filters[:] - updated_config.device_filters.extend( - ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) - return updated_config - @property - def _num_replicas_in_sync(self): - return self._device_map.num_replicas_in_graph - - @property - def worker_devices(self): - return self._device_map.all_devices - - @property - def worker_devices_by_replica(self): - return self._device_map.devices_by_replica - - @property - def parameter_devices(self): - return self._parameter_devices - - def non_slot_devices(self, var_list): - return min(var_list, key=lambda x: x.name) - - @property - def experimental_between_graph(self): - # TODO(yuefengz): Should this return False in the local case? - return True - - @property - def experimental_should_init(self): - return self._is_chief +class ParameterServerExtended(CoreParameterServerExtended): + """Implementation of ParameterServerStrategy.""" - @property - def should_checkpoint(self): - return self._is_chief + def __init__(self, container_strategy, num_gpus_per_worker): + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators=num_gpus_per_worker) + super(ParameterServerExtended, self).__init__( + container_strategy, cluster_resolver=cluster_resolver) - @property - def should_save_summary(self): - return self._is_chief + def _make_dataset_iterator(self, dataset): + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """The contrib version of PS strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 805c643e679338467bf576e17baa8bf839f3b292..89dcdbcfc2f1f9d8cd46db9ccf133be08ff89533 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -29,10 +29,13 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -45,10 +48,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import ClusterSpec CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -62,6 +67,57 @@ def _get_replica_id_integer(): return replica_id +class MockCoreParameterServerStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCoreParameterServerStrategy, self).__init__( + core_parameter_server_strategy.ParameterServerStrategyExtended( + self, cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + sess_config=None, + use_core_strategy=False): + sess_config = sess_config or config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + distribution = MockCoreParameterServerStrategy(cluster_resolver) + sess_config = copy.deepcopy(sess_config) + sess_config = distribution.update_config_proto(sess_config) + else: + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if task_type: + sess_config = copy.deepcopy(sess_config) + distribution.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + target = '' + + return distribution, target, sess_config + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -75,24 +131,27 @@ class ParameterServerStrategyTestBase( self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True) super(ParameterServerStrategyTestBase, self).setUp() - def _get_test_objects(self, task_type, task_id, num_gpus): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - if not task_type: - return distribution, '', self._sess_config - - sess_config = copy.deepcopy(self._sess_config) - distribution.configure( - session_config=sess_config, + def _get_test_objects(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): + return create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, - task_id=task_id) - return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id], - sess_config) - - def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + task_id=task_id, + num_gpus=num_gpus, + sess_config=self._sess_config, + use_core_strategy=use_core_strategy) + + def _test_device_assignment_distributed(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, config=sess_config) as sess, \ @@ -131,7 +190,7 @@ class ParameterServerStrategyTestBase( '/job:worker/replica:0/task:0/%s' % last_part_device) # The colocate_vars_with can override the distribution's device. - with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -177,7 +236,7 @@ class ParameterServerStrategyTestBase( self.assertIn('/job:ps/', h.device) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -190,9 +249,10 @@ class ParameterServerStrategyTestBase( self.assertEqual(f_val, 46.0) def _test_device_assignment_distributed_enable_partitioner( - self, task_type, task_id, num_gpus): - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) - num_shards = len(d.parameter_devices) + self, task_type, task_id, num_gpus, use_core_strategy=False): + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + num_shards = len(d.extended.parameter_devices) partitioner = partitioned_variables.fixed_size_partitioner(num_shards) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, @@ -224,39 +284,18 @@ class ParameterServerStrategyTestBase( self.assertEqual(var.device, '/job:ps/task:%d' % part_id) self.assertEqual(var.device, x_add[part_id].device) - # The colocate_vars_with can override the distribution's device. - with d.colocate_vars_with(x_add[0]): - y = variable_scope.get_variable( - 'y', - initializer=constant_op.constant([20.0, 10.0]), - aggregation=variable_scope.VariableAggregation.SUM, - partitioner=partitioner) - y_add = y.assign_add( - [array_ops.identity(x_add[0]), - array_ops.identity(x_add[1])]) - - for part_id, var in enumerate(y): - self.assertEqual(var.device, '/job:ps/task:0') - self.assertEqual(y_add[part_id].device, var.device) - self.assertEqual(var.device, x_add[0].device) + return x_add - return x_add, y_add - - x, y = d.call_for_each_replica(model_fn) + x = d.extended.call_for_each_replica(model_fn) if context.num_gpus() >= 1: variables.global_variables_initializer().run() - x_val, y_val = sess.run([x, y]) + x_val = sess.run(x) if num_gpus < 1: self.assertEqual(x_val, [13.0, 25.0]) - self.assertEqual(y_val, [33.0, 35.0]) else: x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] - y_expect = [ - 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus - ] self.assertEqual(x_val, x_expect) - self.assertEqual(y_val, y_expect) def _test_device_assignment_local(self, d, @@ -305,7 +344,7 @@ class ParameterServerStrategyTestBase( self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) # The colocate_vars_with can override the distribution's device. - with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -348,7 +387,7 @@ class ParameterServerStrategyTestBase( device_util.canonicalize(h.device)) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -360,9 +399,13 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def _test_simple_increment(self, task_type, task_id, num_gpus): + def _test_simple_increment(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if d.extended._cluster_spec: num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) if 'chief' in d.extended._cluster_spec.as_dict(): @@ -395,7 +438,7 @@ class ParameterServerStrategyTestBase( train_op = control_flow_ops.group(x_add, y_add, z_add) return x, y, z, train_op - x, y, z, train_op = d.call_for_each_replica(model_fn) + x, y, z, train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(train_op) if context.num_gpus() < d.extended._num_gpus_per_worker: @@ -430,9 +473,13 @@ class ParameterServerStrategyTestBase( y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if task_type: # Multi-worker assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec @@ -472,7 +519,7 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -484,7 +531,7 @@ class ParameterServerStrategyTestBase( g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): + d.extended.update(v, update, args=(g,), group=False)): after_list.append(d.extended.read_var(v)) return before_list, after_list @@ -518,10 +565,15 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + use_core_strategy=False): distribution, master_target, config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ @@ -551,9 +603,11 @@ class ParameterServerStrategyTestBase( self.assertEqual(expected_value, computed_value) -class ParameterServerStrategyTest(ParameterServerStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class ParameterServerStrategyTest( + ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -561,66 +615,93 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] - def test_num_replicas_in_sync(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) # All the devices on a given worker are in sync which in this case is the # number of gpus on each worker. - self.assertEqual(2, distribution.num_replicas_in_sync) + self.assertEqual(2, strategy.num_replicas_in_sync) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalCPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=0, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + strategy, compute_device='CPU', variable_device='CPU', num_gpus=0) - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + strategy, compute_device='GPU', variable_device='GPU', num_gpus=1) - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + strategy, compute_device='GPU', variable_device='CPU', num_gpus=2) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - self._test_device_assignment_distributed('worker', 1, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributed(self, num_gpus, use_core_strategy): + self._test_device_assignment_distributed( + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus, + use_core_strategy): self._test_device_assignment_distributed_enable_partitioner( - 'worker', 1, num_gpus) + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testLocalSimpleIncrement(self, num_gpus): - self._test_simple_increment(None, 0, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testLocalSimpleIncrement(self, num_gpus, use_core_strategy): + self._test_simple_increment(None, 0, num_gpus, use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphDistributed(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphDistributed(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphLocal(self, num_gpus): - self._test_minimize_loss_graph(None, None, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): + self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorDistributed(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') dataset_fn = lambda: dataset_ops.Dataset.range(100) @@ -632,12 +713,21 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorLocal(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') dataset_fn = lambda: dataset_ops.Dataset.range(100) @@ -649,23 +739,31 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) # only one worker and pipeline for local. - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + use_core_strategy=use_core_strategy) - def testGlobalStepUpdate(self): - strategy = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepUpdate(self, use_core_strategy): + strategy, _, _ = create_test_objects(use_core_strategy=use_core_strategy) self._test_global_step_update(strategy) - def testUpdateConfigProtoMultiWorker(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - distribution.configure( + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoMultiWorker(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + strategy.configure( cluster_spec=self._cluster_spec, task_type='worker', task_id=1) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify device filters. self.assertEqual(['/job:worker/task:1', '/job:ps'], @@ -674,16 +772,48 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, # Verify isolate_session_state self.assertFalse(new_config.isolate_session_state) - def testUpdateConfigProtoLocal(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoLocal(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto() - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + def testAllReduceSum(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradient_tape(distribution) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -694,20 +824,31 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2, has_chief=True) cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) - def testGlobalStepIsWrapped(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() self.assertEqual(created_step, get_step, @@ -716,19 +857,55 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, id(get_step), get_step.__class__.__name__))) self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) + self.assertIs(strategy, created_step.distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type %s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(resource_variable_ops.ResourceVariable, type(created_step)) + self.assertIs(resource_variable_ops.ResourceVariable, type(get_step)) + # All variables have an _distribute_strategy parameter. Only variable + # subclasses in distribution strategy expose it publicly. + self.assertFalse(hasattr(strategy, 'distribute_strategy')) + self.assertIs(strategy, created_step._distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testValueContainer(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): - def testValueContainer(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): def f(): with backprop.GradientTape() as tape: v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.extended.value_container(v) + w = strategy.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.extended.call_for_each_replica(f) + + strategy.extended.call_for_each_replica(f) + + +class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'], + use_core_strategy=[True, False], + required_gpus=2)) + def testNumpyIterator(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index faeb96bcb7c516b1e494661ef2cbe8dad476ab55..27aad46b97195aa498d0382f08c04c312cebbe65 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -33,6 +32,9 @@ class Step(object): def distribution(self): return self._distribution + def initialize(self): + return [] + def __call__(self): """Perform one step of this training algorithm.""" raise NotImplementedError("must be implemented in descendants") @@ -50,12 +52,10 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) - self._distributed_input = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - self._iterator = self._distributed_input.make_one_shot_iterator() - else: - # TODO(priyag): Expose initializer via some initializer property. - self._iterator = self._distributed_input.make_initializable_iterator() + self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) + + def initialize(self): + return self._iterator.initialize() class StandardSingleLossStep(StandardInputStep): @@ -99,7 +99,7 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = self.distribution.call_for_each_replica( + grads_and_vars = self.distribution.extended.call_for_each_replica( gradients_fn, args=(ctx, inputs)) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. @@ -109,6 +109,6 @@ class StandardSingleLossStep(StandardInputStep): self.distribution, grads_and_vars) # TODO(priyag): Return the outputs, context, etc as well. - ctx = self.distribution.run_steps_on_dataset( + ctx = self.distribution.extended.experimental_run_steps_on_iterator( step_fn, self._iterator, self._iterations_per_step) return ctx.run_op diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec13351b098d47ed3ff62f689a625a31..9f48560b2666036e149a63c98b6529fb24cc5067 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -45,24 +45,21 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): single_loss_step, layer = single_loss_example( optimizer_fn, distribution, use_bias=True, iterations_per_step=2) - self.evaluate(distribution.initialize()) if context.executing_eagerly(): + single_loss_step.initialize() run_step = single_loss_step else: with self.cached_session() as sess: - sess.run(single_loss_step._iterator.initializer) + sess.run(single_loss_step.initialize()) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 6e5280e35632d3f3cb6a4fe172a15fb7f508354c..2e2ee92b6e20471f367895ea53c0864bb3d1dae7 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values @@ -31,6 +34,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -41,25 +45,26 @@ class _TestException(Exception): pass -# May be the argument to either distribution.call_for_each_replica() or +# May be the argument to either distribution.extended.call_for_each_replica() or # get_replica_context().merge_call() def _raise_exception_fn(_=None): raise _TestException() -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that raises an exception. def _merge_raises_fn(): ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that raises an exception. +# dist.extended.call_for_each_replica() with a function that raises an +# exception. def _call_raises_fn(dist): - dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) -# Must be the argument to a distribution.call_for_each_replica() call, +# Must be the argument to a distribution.extended.call_for_each_replica() call, # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. def _merge_call_raises_fn(): @@ -67,15 +72,16 @@ def _merge_call_raises_fn(): # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that calls a +# dist.extended.call_for_each_replica() with a function that calls a # get_replica_context().merge_call() that raises an exception. def _call_merge_raises_fn(dist): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that calls a call_for_each_replica() that -# calls a get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that calls a +# call_for_each_replica() that calls a get_replica_context().merge_call() that +# raises an exception. def _merge_call_merge_raises_fn(): ds_context.get_replica_context().merge_call(_call_merge_raises_fn) @@ -106,7 +112,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -118,8 +124,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): after_list.append(d.extended.read_var(v)) return before_list, after_list @@ -162,7 +168,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -173,8 +179,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): after_list.append(d.extended.read_var(v)) return before_list, after_list @@ -202,20 +208,20 @@ class DistributionTestBase(test.TestCase): self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True - d.call_for_each_replica(mark_devices_fn) + d.extended.call_for_each_replica(mark_devices_fn) self.assertAllEqual(expected_devices, [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): with self.assertRaises(_TestException): - dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_raises_fn) + dist.extended.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_call_merge_raises_fn) def _input_fn_to_test_input_context(self, dataset_fn, @@ -287,8 +293,195 @@ class DistributionTestBase(test.TestCase): value = global_step.read_value() return train_op, value - train_ops, value = strategy.call_for_each_replica(model_fn) + train_ops, value = strategy.extended.call_for_each_replica(model_fn) self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) self.assertEqual((1,) * len(global_step_tensors), global_step_values) + + def _test_numpy_iterator(self, strategy): + with strategy.scope(), self.cached_session() as sess: + x = np.asarray([[1, 2], [6, 12], [2, 4], + [5, 10], [3, 6], [4, 8]]) + y = np.asarray([5, 4, 3, 2, 1, 0]) + batch_size = 6 + if not strategy.extended._global_batch_size: # pylint: disable=protected-access + batch_size = batch_size // strategy.num_replicas_in_sync + i = strategy.experimental_make_numpy_iterator( + (x, y), batch_size=batch_size, num_epochs=2, shuffle=None, + session=sess) + self.evaluate(i.initialize()) + + def run_and_concatenate(strategy, i): + x, y = strategy.experimental_run(lambda z: z, i) + x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + return np.concatenate(x), np.concatenate(y) + + x_1, y_1 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_1) + self.assertAllEqual(y, y_1) + x_2, y_2 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_2) + self.assertAllEqual(y, y_2) + with self.assertRaises(errors.OutOfRangeError): + run_and_concatenate(strategy, i) + + +class OneDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any one-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0]], outputs[0]) + self.assertAllEqual([expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +class TwoDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any two-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(2., [21., 21.5])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0], expected[0]], outputs[0]) + self.assertAllEqual([expected[1], expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +def _all_sum(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.SUM, value) + + +def _all_mean(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 7352203fe11b3036229119e06872aed5e160b715..4387210062e42bb1ab7e2351008a45979224ff1a 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,10 +21,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy -import functools from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib +from tensorflow.contrib.tpu.python.tpu import topology from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop @@ -33,12 +35,15 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -50,6 +55,29 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +def initialize_tpu_system(cluster_resolver=None): + """Initialize the TPU devices in a separate session and graph. + + Args: + cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + Returns: + The tf.contrib.tpu.Topology object for the topology of the TPU cluster. + """ + if cluster_resolver is None: + cluster_resolver = resolver_lib.TPUClusterResolver("") + master = cluster_resolver.master() + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + serialized_topology = sess.run(tpu.initialize_system()) + logging.info("Finished initializing TPU system.") + return topology.Topology(serialized=serialized_topology) + + def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -68,12 +96,13 @@ def get_tpu_system_metadata(tpu_cluster_resolver): # TODO(jhseu): Deduplicate with MirroredStrategy? def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring - device_map, logical_device, real_mirrored_creator, *args, **kwargs): + strategy, device_map, logical_device, real_mirrored_creator, + *args, **kwargs): # Figure out what collections this variable should be added to. # We'll add the TPUMirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] + var_collections = kwargs.pop("collections", None) + if var_collections is None: + var_collections = [ops.GraphKeys.GLOBAL_VARIABLES] kwargs["collections"] = [] # TODO(jhseu): Should we have different behavior for different @@ -101,7 +130,8 @@ def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring devices = device_map.logical_to_actual_devices(logical_device) value_list = real_mirrored_creator(devices, *args, **kwargs) result = values.TPUMirroredVariable( - device_map, value_list, aggregation, logical_device=logical_device) + strategy, device_map, value_list, aggregation, + logical_device=logical_device) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -111,11 +141,11 @@ def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring # "trainable" to False for next_creator() since that causes functions # like implicit_gradients to skip those variables. if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) for v in value_list: l.remove(v) - g.add_to_collections(collections, result) + g.add_to_collections(var_collections, result) return result @@ -125,7 +155,8 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def __init__(self, tpu_cluster_resolver=None, steps_per_run=None, - num_cores=None): + device_assignment=None, + **kwargs): """Initializes the TPUStrategy object. Args: @@ -136,31 +167,82 @@ class TPUStrategy(distribute_lib.DistributionStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. - num_cores: Number of cores to use on the TPU. If None specified, then - auto-detect the cores and topology of the TPU system. + device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify + the placement of replicas on the TPU cluster. Currently only supports + the usecase of using a single core within a TPU cluster. + **kwargs: Additional experimental flags. Will be removed in future. """ + if len(kwargs) > 1: + raise ValueError("TPUStrategy constructor only takes one experimental " + "flag now") + elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs: + raise ValueError("TPUStrategy constructor does not support arguments: " + "{}".format(kwargs)) + super(TPUStrategy, self).__init__(TPUExtended( - self, tpu_cluster_resolver, steps_per_run, num_cores)) + self, tpu_cluster_resolver, steps_per_run, device_assignment, + kwargs.get("_disable_training_loop_on_host", False))) @property def steps_per_run(self): """DEPRECATED: use .extended.steps_per_run instead.""" return self._extended.steps_per_run + # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this + # can use the default implementation. + # This implementation runs a single step. It does not use infeed or outfeed. + def experimental_run(self, fn, input_iterator=None): + """See base class.""" + if context.executing_eagerly(): + raise NotImplementedError("Eager mode not supported in TPUStrategy.") + + if self.extended._disable_training_loop_on_host: # pylint: disable=protected-access + raise NotImplementedError( + "`experimental_run` is not compatible with " + "`_disable_training_loop_on_host=True`") + + if input_iterator is None: + inputs = [] + else: + inputs = input_iterator.get_next() + + result = [None] + def replicated_fn(replica_id, inputs): + """Wraps user function to provide replica ID and `Tensor` inputs.""" + with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id): + if input_iterator is None: + result[0] = fn() + else: + result[0] = fn(inputs) + return result[0] + + replicate_inputs = [] # By replica. + for i in range(self.num_replicas_in_sync): + replicate_inputs.append( + [constant_op.constant(i, dtype=dtypes.int32), + values.select_replica(i, inputs)]) + + with self.scope(): + replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs) + + # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. + replicate_outputs = [ + nest.pack_sequence_as(result[0], nest.flatten(replica_outputs)) + for replica_outputs in replicate_outputs] + + device_map = self.extended._device_map # pylint: disable=protected-access + return values.regroup(device_map, replicate_outputs) + class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" - # Track what TPU devices have been initialized. This is *intentionally* - # shared across all instances of TPUExtended as we want to keep track of which - # devices are initialized globally. - _initialized_devices = [] - def __init__(self, container_strategy, tpu_cluster_resolver=None, steps_per_run=None, - num_cores=None): + device_assignment=None, + disable_training_loop_on_host=False): super(TPUExtended, self).__init__(container_strategy) if tpu_cluster_resolver is None: @@ -173,8 +255,22 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) - # TODO(sourabhbajaj): Change this from num_cores to metadata_override - self._num_cores_override = num_cores + self._device_assignment = device_assignment + self._disable_training_loop_on_host = disable_training_loop_on_host + + # Device assignment is currently only supported for 1 core case. + if self._device_assignment: + assert isinstance(self._device_assignment, + device_assignment_lib.DeviceAssignment) + if self._device_assignment.num_replicas != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if self._device_assignment.num_cores_per_replica != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]): + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") # TODO(jhseu): Switch to DeviceAssignment to support pods and model # parallelism. @@ -188,45 +284,33 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] self._device_map = values.ReplicaDeviceMap(self._tpu_devices) - # For input: - input_device_map = values.ReplicaDeviceMap(tuple( - self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - self._input_workers = values.InputWorkers(input_device_map, worker_devices) + # If the training loop is on the device, we must use the infeed, with input + # on the host. Otherwise, we preload the data onto the TPUs. + if disable_training_loop_on_host: + input_device_map = values.ReplicaDeviceMap(tuple( + self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + self._input_workers = input_lib.InputWorkers( + input_device_map, worker_devices) + else: + input_worker_devices = collections.OrderedDict() + for tpu_device in self._tpu_devices: + host_device = _get_host_for_device(tpu_device) + input_worker_devices.setdefault(host_device, []) + input_worker_devices[host_device].append(tpu_device) + self._input_workers = input_lib.InputWorkers( + self._device_map, tuple(input_worker_devices.items())) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run self._require_static_shapes = True - # Initialize the TPU devices. - self._initialize_tpu() - - def _initialize_tpu(self): - """Initialize the TPU devices in a separate session and graph. - - We keep track of all the TPU devices that we're initialized as we should - only be running TPU initialize once for the entire process. - """ - master = self._tpu_cluster_resolver.master() - # Verify TPU has not already been initialized in this process. - if master in TPUExtended._initialized_devices: - logging.info("TPU master %s has already been initialized." % master) - return - - logging.info("Initializing the TPU system.") - session_config = config_pb2.ConfigProto(allow_soft_placement=True) - self._configure(session_config) - with ops.Graph().as_default(): - with session_lib.Session(config=session_config, target=master) as sess: - sess.run([tpu.initialize_system()]) - logging.info("Finized initializing TPU system.") - - # Update Strategy state to make sure we can track device initialization. - TPUExtended._initialized_devices.append(master) + def _validate_colocate_with_variable(self, colocate_with_variable): + values.validate_colocate_tpu_variable(colocate_with_variable, self) def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): @@ -291,20 +375,44 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - - return values.DatasetIterator(dataset, self._input_workers, - self._num_replicas_in_sync) - - def _distribute_dataset(self, dataset_fn): - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + input_contexts = [] + num_workers = self._input_workers.num_workers + for i in range(num_workers): + input_contexts.append(distribute_lib.InputContext( + num_input_pipelines=num_workers, + input_pipeline_id=i, + num_replicas_in_sync=self._num_replicas_in_sync)) + return input_lib.InputFunctionIterator( + input_fn, self._input_workers, input_contexts) + + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), + session) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. def _experimental_run_steps_on_iterator( self, fn, multi_worker_iterator, iterations, initial_loop_values=None): + if self._disable_training_loop_on_host: + impl = self._run_steps_on_iterator_with_device_loop + else: + impl = self._run_steps_on_iterator_with_host_loop + + return impl( + fn=fn, multi_worker_iterator=multi_worker_iterator, + iterations=iterations, initial_loop_values=initial_loop_values) + + def _run_steps_on_iterator_with_host_loop( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) if any(not s.is_fully_defined() for s in shapes): @@ -312,26 +420,16 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "TPU currently requires fully defined shapes. Either use " "set_shape() on the input tensors or use " "dataset.batch(..., drop_remainder=True).") - types = nest.flatten(multi_worker_iterator.output_types) - - enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, - iterations) - for host_id in range(self.num_hosts)] - - def dequeue_fn(): - dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() - def run_fn(): + def run_fn(inputs): """Single step on the TPU device.""" - fn_result = fn(ctx, dequeue_fn()) + fn_result = fn(ctx, inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -351,7 +449,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def rewrite_fn(*args): """The rewritten step fn running on TPU.""" del args - replicate_inputs = [[]] * self._num_replicas_in_sync + + per_replica_inputs = multi_worker_iterator.get_next() + replicate_inputs = [] + for replica_id in range(self._num_replicas_in_sync): + select_replica = lambda x: values.select_replica(replica_id, x) # pylint: disable=cell-var-from-loop + replicate_inputs.append((nest.map_structure( + select_replica, per_replica_inputs),)) + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) # If run_fn has tensor outputs, tpu.replicate returns a list of list. We @@ -363,8 +468,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): return replicate_outputs - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn + # TODO(sourabhbajaj): The input to while loop should be based on the + # output type of the step_fn assert isinstance(initial_loop_values, list) initial_loop_values = initial_loop_values * self._num_replicas_in_sync @@ -374,7 +479,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values) del self._outer_control_flow_context - ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + ctx.run_op = control_flow_ops.group(replicate_outputs) if isinstance(replicate_outputs, list): # Filter out any ops from the outputs, typically this would be the case @@ -399,23 +504,80 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # no tensors returned. last_step_tensor_outputs = [] - # Convert replicate_outputs to the original dict structure of - # last_step_outputs. - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been reduced, take the first value - # from the list as each value should be the same. Else return the full - # list of values. - # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica - # value. - if reduce_op is not None: - # TODO(priyag): Should this return the element or a list with 1 element - last_step_tensor_outputs_dict[name] = output[0] - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + _set_last_step_outputs(ctx, last_step_tensor_outputs) + return ctx + def _run_steps_on_iterator_with_device_loop( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): + output_shapes = multi_worker_iterator.output_shapes + shapes = nest.flatten(output_shapes) + if any(not s.is_fully_defined() for s in shapes): + raise ValueError( + "TPU currently requires fully defined shapes. Either use " + "set_shape() on the input tensors or use " + "dataset.batch(..., drop_remainder=True).") + types = nest.flatten(multi_worker_iterator.output_types) + + enqueue_ops = [ + self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, + iterations) + for host_id in range(self.num_hosts)] + + def dequeue_fn(): + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + return nest.pack_sequence_as(output_shapes, dequeued) + + # Wrap `fn` for repeat. + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = input_lib.MultiStepContext() + + def run_fn(*args, **kwargs): + """Single step on the TPU device.""" + del args, kwargs + fn_result = fn(ctx, dequeue_fn()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result + + def iterate_on_tpu(): + return training_loop.repeat(iterations, run_fn, initial_loop_values) + + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop and TPU replicate context. This is useful in cases + # where we might need to exit these contexts and get back to the outer + # context to do some things, for e.g. create an op which should be + # evaluated only once at the end of the loop on the host. One such usage + # is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + + del self._outer_control_flow_context + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in + zip(*last_step_tensor_outputs)] + + _set_last_step_outputs(ctx, last_step_tensor_outputs) return ctx def _call_for_each_replica(self, fn, args, kwargs): @@ -424,19 +586,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def _initialize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] + def _experimental_initialize_system(self): + """Experimental method added to be used by Estimator. - def _finalize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [] + This is a private method only to be used by Estimator. Other frameworks + should directly be calling `tf.contrib.distribute.initialize_tpu_system` + """ + initialize_tpu_system(self._tpu_cluster_resolver) def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" @@ -444,6 +600,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if colocate_with is None: device_map = self._device_map logical_device = 0 # TODO(josh11b): Get logical device from scope here. + elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) else: device_map = colocate_with.device_map logical_device = colocate_with.logical_device @@ -475,7 +634,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): return value_list return _create_tpu_mirrored_variable( - device_map, logical_device, _real_mirrored_creator, *args, **kwargs) + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access @@ -559,15 +719,34 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): @property def num_hosts(self): - return self._tpu_metadata.num_hosts + if self._device_assignment is None: + return self._tpu_metadata.num_hosts + + return len(set([self._device_assignment.host_device(r) + for r in range(self._device_assignment.num_replicas)])) @property def num_replicas_per_host(self): - return self._tpu_metadata.num_of_cores_per_host + if self._device_assignment is None: + return self._tpu_metadata.num_of_cores_per_host + + # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed + # as the computation of num_replicas_per_host is not a constant + # when using device_assignment. This is a temporary workaround to support + # StatefulRNN as everything is 1 in that case. + # This method needs to take host_id as input for correct computation. + max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // + self._device_assignment.num_cores_per_replica) + models_per_host = min(self._device_assignment.num_replicas, + max_models_per_host) + return models_per_host * self._device_assignment.num_cores_per_replica @property def _num_replicas_in_sync(self): - return self._num_cores_override or self._tpu_metadata.num_cores + if self._device_assignment is None: + return self._tpu_metadata.num_cores + return (self._device_assignment.num_replicas * + self._device_assignment.num_cores_per_replica) @property def experimental_between_graph(self): @@ -635,6 +814,13 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `make_input_fn_iterator` assumes per-replica batching. + + Returns: + Boolean. + """ return True @@ -642,15 +828,48 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" # TODO(sourabhbajaj): Call for each replica should be updating this. - def __init__(self, strategy): - # TODO(b/118385803): properly initialize replica_id, instead of always 0 - replica_id = constant_op.constant(0, dtypes.int32) + # TODO(b/118385803): Always properly initialize replica_id. + def __init__(self, strategy, replica_id_in_sync_group=None): + if replica_id_in_sync_group is None: + replica_id_in_sync_group = constant_op.constant(0, dtypes.int32) distribute_lib.ReplicaContext.__init__( - self, strategy, replica_id_in_sync_group=replica_id) + self, strategy, replica_id_in_sync_group=replica_id_in_sync_group) @property def devices(self): distribute_lib.require_replica_context(self) ds = self._strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return (ds.extended.worker_devices[replica_id],) + + if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`. + # TODO(cjfj): Return other devices when model parallelism is supported. + return (tpu.core(0),) + else: + return (ds.extended.worker_devices[replica_id],) + + +def _get_host_for_device(device): + spec = tf_device.DeviceSpec.from_string(device) + return tf_device.DeviceSpec( + job=spec.job, replica=spec.replica, task=spec.task, + device_type="CPU", device_index=0).to_string() + + +def _set_last_step_outputs(ctx, last_step_tensor_outputs): + """Sets the last step outputs on the given context.""" + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been reduced, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica + # value. + if reduce_op is not None: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 0e8e86f6b9647ebf06890c9bb343a8f8e0fcc698..51c58b0b2f3dc2ab63e22718825a471b8657f892 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -22,28 +22,20 @@ import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.util import nest class DistributedValuesTest(test.TestCase): @@ -191,7 +183,7 @@ def _make_mirrored(): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) device_map = values.ReplicaDeviceMap(devices) - mirrored = values.MirroredVariable(device_map, v, + mirrored = values.MirroredVariable(None, device_map, v, variable_scope.VariableAggregation.SUM) return v, device_map, mirrored @@ -314,7 +306,7 @@ class RegroupAndSelectDeviceTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) device_map = values.ReplicaDeviceMap((d,)) - mirrored = values.MirroredVariable(device_map, (v,), + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.SUM) result = values.regroup(device_map, (v,)) self.assertIs(mirrored, result) @@ -354,444 +346,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next_as_list() - computed_value = self.evaluate(next_element) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next_as_list() - self.evaluate(next_element) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # Using random input since that is only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map) - per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next_as_list() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. - self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for r, device in enumerate(devices): - v = values.select_replica(r, next_element) - # The `v` here can be a tuple. - for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - t = [values.select_replica(r, next_element) for r in range(len(devices))] - actual = sess.run(t) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_replica(r, next_element) - for r in range(len(devices))]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values): - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"]) - ) - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = ( - ("/job:worker/replica:0/task:0", ( - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - )), - ("/job:worker/replica:0/task:1", ( - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - )) - ) - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset( - dataset_fn, worker_devices, devices, - [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(8)] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_devices) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, input_workers) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator( - sess, multi_worker_iterator, devices, - [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. - d1 = "/device:GPU:0" - d2 = "/device:GPU:1" - device_map = values.ReplicaDeviceMap([d1, d2]) - input_workers = values.InputWorkers( - device_map, (("w1", (d1,)), ("w2", (d2,)))) - with self.assertRaises(ValueError): - values.MultiWorkerDataIterator([("w1", None)], input_workers) - - def testDuplicateDevices(self): - _, devices = self._cpu_devices() - devices.append("/job:worker/replica:0/task:0/device:CPU:0") - with self.assertRaises(ValueError): - _ = values.ReplicaDeviceMap(devices) - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - device_map = values.ReplicaDeviceMap(devices) - input_workers = values.InputWorkers(device_map, worker_device_pairs) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator( - input_fn, input_workers, input_contexts) - else: - iterator = values.DatasetIterator( - dataset_fn(), input_workers, split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate([values.select_replica(r, next_element) - for r in range(len(devices))]) - - # After re-initializing the iterator, should be able to iterate again. - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) - - -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = values._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() @@ -813,7 +367,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) - mirrored = values.MirroredVariable(device_map, (v,), + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, mirrored.name) @@ -952,7 +506,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) mirrored = values.MirroredVariable( - values.ReplicaDeviceMap(("/device:GPU:0",)), (v,), + distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,), variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) @@ -961,14 +515,14 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): _devices = ("/device:GPU:0", "/device:CPU:0") -def _make_replica_local(method): +def _make_replica_local(method, strategy=None): device_map = values.ReplicaDeviceMap(_devices) v = [] for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - replica_local = values.ReplicaLocalVariable(device_map, v, method) + replica_local = values.ReplicaLocalVariable(strategy, device_map, v, method) return v, replica_local @@ -996,7 +550,7 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): name="v", initializer=[1.], use_resource=True) device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) replica_local = values.ReplicaLocalVariable( - device_map, (v,), variable_scope.VariableAggregation.MEAN) + None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, replica_local.name) self.assertEqual(v.dtype, replica_local.dtype) @@ -1043,7 +597,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1066,7 +620,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1086,7 +640,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1102,7 +656,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum") + v, replica_local = _make_replica_local("sum", distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) @@ -1149,7 +703,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) @@ -1164,7 +718,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 452628257ea96713453bf2aa32b5baa9d6d0cb86..1006dfac49f36baa7cf5136f6f2982e3fd965298 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -249,9 +249,9 @@ class InverseGamma(distribution.Distribution): `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`.""") def _variance(self): - var = (math_ops.square(self.rate) - / math_ops.square(self.concentration - 1.) - / (self.concentration - 2.)) + var = ( + math_ops.square(self.rate) / math_ops.squared_difference( + self.concentration, 1.) / (self.concentration - 2.)) if self.allow_nan_stats: nan = array_ops.fill( self.batch_shape_tensor(), diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 257d02057ae0d280074559aa9e97725bf5cc3fd0..78ab155896cfeda4dd259a8529f4b1f77a12cf0b 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -200,13 +200,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testGpuDefinedDataset(self): - with ops.device(test.gpu_device_name()): - ds = Dataset.from_tensors([0., 1.]) - for x in ds: - y = math_ops.add(x, x) - self.assertAllEqual([0., 2.], y.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 97c299a911c9180bf69faa0fa46527e80eada790..3e0881754c750f4d36e2e4dd8b80835b031c658c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,16 +6,16 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ - "//tensorflow/contrib/eager/python/examples/densenet", - "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/densenet:densenet_lib", + "//tensorflow/contrib/eager/python/examples/gan:mnist_lib", "//tensorflow/contrib/eager/python/examples/l2hmc", "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", - "//tensorflow/contrib/eager/python/examples/linear_regression", + "//tensorflow/contrib/eager/python/examples/linear_regression:linear_regression_lib", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/revnet", "//tensorflow/contrib/eager/python/examples/revnet:config", - "//tensorflow/contrib/eager/python/examples/rnn_colorbot", - "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/rnn_colorbot:rnn_colorbot_lib", + "//tensorflow/contrib/eager/python/examples/rnn_ptb:rnn_ptb_lib", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index e2154fcc5fcf774dcd52285d9442dfd5073a4992..fbb5daf230bb79f08a3d071062ddc0e8507ab324 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -9,6 +9,13 @@ py_binary( name = "densenet", srcs = ["densenet.py"], srcs_version = "PY2AND3", + deps = [":densenet_lib"], +) + +py_library( + name = "densenet_lib", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -17,33 +24,37 @@ py_binary( cuda_py_test( name = "densenet_test", - size = "large", + size = "medium", srcs = ["densenet_test.py"], additional_deps = [ ":densenet", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "optonly", + "oss_serial", ], ) cuda_py_test( name = "densenet_graph_test", - size = "large", + size = "medium", srcs = ["densenet_graph_test.py"], additional_deps = [ ":densenet", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index d64c8eb9ce122fa277567b2fbc632abfbc72df64..d99a519112787bad664232983208279cfb4d0036 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -9,6 +9,13 @@ py_binary( name = "mnist", srcs = ["mnist.py"], srcs_version = "PY2AND3", + deps = [":mnist_lib"], +) + +py_library( + name = "mnist_lib", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -20,7 +27,7 @@ cuda_py_test( name = "mnist_test", srcs = ["mnist_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], @@ -30,7 +37,7 @@ cuda_py_test( name = "mnist_graph_test", srcs = ["mnist_graph_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index 1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95..e1a02db76f705414a34d232022f50124a5a6a3ed 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -13,11 +13,13 @@ "\n", "# Convolutional VAE: An example with tf.keras and eager\n", "\n", + "This example has moved:\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { @@ -28,604 +30,14 @@ }, "source": [ "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", - "\n", - "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", "\n" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "P-JuIu2N_SQf" - }, - "outputs": [], - "source": [ - "# to generate gifs\n", - "!pip install imageio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable Eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", - "import tensorflow as tf\n", - "tfe = tf.contrib.eager\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import glob\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "import imageio\n", - "from IPython import display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the MNIST dataset\n", - "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "a4fYMGxGhrna" - }, - "outputs": [], - "source": [ - "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NFC2ghIdiZYE" - }, - "outputs": [], - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", - "\n", - "# Normalizing the images to the range of [0., 1.]\n", - "train_images /= 255.\n", - "test_images /= 255.\n", - "\n", - "# Binarization\n", - "train_images[train_images \u003e= .5] = 1.\n", - "train_images[train_images \u003c .5] = 0.\n", - "test_images[test_images \u003e= .5] = 1.\n", - "test_images[test_images \u003c .5] = 0." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "S4PIDhoDLbsZ" - }, - "outputs": [], - "source": [ - "TRAIN_BUF = 60000\n", - "BATCH_SIZE = 100\n", - "\n", - "TEST_BUF = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use *tf.data* to create batches and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-yKCCQOoJ7cn" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", - "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Wire up the generative and inference network with *tf.keras.Sequential*\n", - "\n", - "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", - "\n", - "### Generative Network\n", - "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", - "\n", - "### Inference Network\n", - "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", - "\n", - "### Reparameterization Trick\n", - "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", - "\n", - "### Network architecture\n", - "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VGLbvBEmjK0a" - }, - "outputs": [], - "source": [ - "class CVAE(tf.keras.Model):\n", - " def __init__(self, latent_dim):\n", - " super(CVAE, self).__init__()\n", - " self.latent_dim = latent_dim\n", - " self.inference_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", - " tf.keras.layers.Conv2D(\n", - " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2D(\n", - " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Flatten(),\n", - " # No activation\n", - " tf.keras.layers.Dense(latent_dim + latent_dim),\n", - " ]\n", - " )\n", - "\n", - " self.generative_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", - " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", - " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=64,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=32,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " # No activation\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", - " ]\n", - " )\n", - "\n", - " def sample(self, eps=None):\n", - " if eps is None:\n", - " eps = tf.random_normal(shape=(100, self.latent_dim))\n", - " return self.decode(eps, apply_sigmoid=True)\n", - "\n", - " def encode(self, x):\n", - " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", - " return mean, logvar\n", - "\n", - " def reparameterize(self, mean, logvar):\n", - " eps = tf.random_normal(shape=mean.shape)\n", - " return eps * tf.exp(logvar * .5) + mean\n", - "\n", - " def decode(self, z, apply_sigmoid=False):\n", - " logits = self.generative_net(z)\n", - " if apply_sigmoid:\n", - " probs = tf.sigmoid(logits)\n", - " return probs\n", - "\n", - " return logits" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss function and the optimizer\n", - "\n", - "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", - "\n", - "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", - "\n", - "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", - "\n", - "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", - "where $z$ is sampled from $q(z|x)$.\n", - "\n", - "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", - " log2pi = tf.log(2. * np.pi)\n", - " return tf.reduce_sum(\n", - " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n", - " axis=raxis)\n", - "\n", - "def compute_loss(model, x):\n", - " mean, logvar = model.encode(x)\n", - " z = model.reparameterize(mean, logvar)\n", - " x_logit = model.decode(z)\n", - "\n", - " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", - " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", - " logpz = log_normal_pdf(z, 0., 0.)\n", - " logqz_x = log_normal_pdf(z, mean, logvar)\n", - " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", - "\n", - "def compute_gradients(model, x):\n", - " with tf.GradientTape() as tape:\n", - " loss = compute_loss(model, x)\n", - " return tape.gradient(loss, model.trainable_variables), loss\n", - "\n", - "optimizer = tf.train.AdamOptimizer(1e-4)\n", - "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", - " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", - "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", - "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", - "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, it is time to generate some images\n", - "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", - "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", - "* Here we plot the probabilities of Bernoulli distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "epochs = 100\n", - "latent_dim = 50\n", - "num_examples_to_generate = 16\n", - "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement.\n", - "random_vector_for_generation = tf.random_normal(\n", - " shape=[num_examples_to_generate, latent_dim])\n", - "model = CVAE(latent_dim)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " predictions = model.sample(test_input)\n", - " fig = plt.figure(figsize=(4,4))\n", - "\n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", - " plt.axis('off')\n", - "\n", - " # tight_layout minimizes the overlap between 2 sub-plots\n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "generate_and_save_images(model, 0, random_vector_for_generation)\n", - "\n", - "for epoch in range(1, epochs + 1):\n", - " start_time = time.time()\n", - " for train_x in train_dataset:\n", - " gradients, loss = compute_gradients(model, train_x)\n", - " apply_gradients(optimizer, gradients, model.trainable_variables)\n", - " end_time = time.time()\n", - "\n", - " if epoch % 1 == 0:\n", - " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset:\n", - " loss(compute_loss(model, test_x))\n", - " elbo = -loss.result()\n", - " display.clear_output(wait=False)\n", - " print('Epoch: {}, Test set ELBO: {}, '\n", - " 'time elapse for current epoch {}'.format(epoch,\n", - " elbo,\n", - " end_time - start_time))\n", - " generate_and_save_images(\n", - " model, epoch, random_vector_for_generation)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "source": [ - "### Display an image using the epoch number" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WfO5wCdclHGL" - }, - "outputs": [], - "source": [ - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "5x3q9_Oe5q0A" - }, - "outputs": [], - "source": [ - "display_image(epochs) # Display images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "source": [ - "### Generate a GIF of all the saved images." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "IGKQgENQ8lEI" - }, - "outputs": [], - "source": [ - "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp cvae.gif cvae.gif.png')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "uV0yiKpzNP1b" - }, - "outputs": [], - "source": [ - "display.Image(filename=\"cvae.gif.png\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "yQXO_dlXkKsT" - }, - "source": [ - "To downlod the animation from Colab uncomment the code below:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4fSJS3m5HLFM" - }, - "outputs": [], - "source": [ - "#from google.colab import files\n", - "#files.download('cvae.gif')" - ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "cvae.ipynb", "private_outputs": true, "provenance": [ @@ -635,8 +47,7 @@ } ], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 78fcd397087fd1fd64aebed7ac3b5c6b2f45c450..53767058838459e56215d286e9f8f8eb66287147 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,26 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "dcgan.ipynb", - "version": "0.3.2", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, - "cell_type": "markdown", "source": [ "**Copyright 2018 The TensorFlow Authors**.\n", "\n", @@ -28,851 +13,39 @@ "\n", "# Generating Handwritten Digits with DCGAN\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "ITZuApL56Mny" - }, - "cell_type": "markdown", - "source": [ - "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. " - ] - }, - { - "metadata": { - "colab_type": "toc", - "id": "x2McrO9bMyLN" - }, - "cell_type": "markdown", - "source": [ - ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", - "\n", - ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", - "\n", - ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", - "\n", - ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", - "\n", - ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", - "\n", - ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", - "\n", - ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", - "\n", - ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", - "\n", - ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", - "\n", - ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", - "\n", - ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", - "\n", - ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", - "\n", - ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", - "\n", - ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "This example has moved.\n", "\n", - ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", - "\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2MbKJY38Puy9" }, - "cell_type": "markdown", "source": [ - "## What are GANs?\n", - "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", - "\n", - "![alt text](https://github.com/margaretmz/tensorflow/blob/margaret-dcgan/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png?raw=1)\n", - "\n", - "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Overtime, the generated images become increasingly difficult to distinguish from the training set.\n", - "\n", - "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)). Now, let's head to the code!\n", - "\n", "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "dcgan.ipynb", + "provenance": [], + "version": "0.3.2" }, - { - "metadata": { - "colab_type": "code", - "id": "u_2z-B3piVsw", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Install imgeio in order to generate an animated gif showing the image generating process\n", - "!pip install imageio" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "cell_type": "markdown", - "source": [ - "### Import TensorFlow and enable eager execution" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "YfIk2es3hJEd", - "colab": {} - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import glob\n", - "import imageio\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import PIL\n", - "import time\n", - "\n", - "from IPython import display" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "cell_type": "markdown", - "source": [ - "### Load the dataset\n", - "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "a4fYMGxGhrna", - "colab": {} - }, - "cell_type": "code", - "source": [ - "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "NFC2ghIdiZYE", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "S4PIDhoDLbsZ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "BUFFER_SIZE = 60000\n", - "BATCH_SIZE = 256" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "cell_type": "markdown", - "source": [ - "### Use tf.data to create batches and shuffle the dataset" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "-yKCCQOoJ7cn", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "cell_type": "markdown", - "source": [ - "## Create the models\n", - "\n", - "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "-tEyxE-GMC48" - }, - "cell_type": "markdown", - "source": [ - "### The Generator Model\n", - "\n", - "The generator is responsible for creating convincing images that are good enough to fool the discriminator. The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." - ] - }, - { - "metadata": { - "id": "6bpTcDqoLWjY", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_generator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " \n", - " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", - " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", - " \n", - " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 7, 7, 128) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 14, 14, 64) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", - " assert model.output_shape == (None, 28, 28, 1)\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "D0IKnaCtg6WE" - }, - "cell_type": "markdown", - "source": [ - "### The Discriminator model\n", - "\n", - "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." - ] - }, - { - "metadata": { - "id": "dw2tPLmk2pEP", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_discriminator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Flatten())\n", - " model.add(tf.keras.layers.Dense(1))\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "gDkA05NE6QMs", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator = make_generator_model()\n", - "discriminator = make_discriminator_model()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "cell_type": "markdown", - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "Let's define the loss functions and the optimizers for the generator and the discriminator.\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Jd-3GCUEiKtv" - }, - "cell_type": "markdown", - "source": [ - "### Generator loss\n", - "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "90BIcCKcDMxz", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PKY_iPSPNWoj" - }, - "cell_type": "markdown", - "source": [ - "### Discriminator loss\n", - "\n", - "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", - "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", - "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", - "3. Calculate the total_loss as the sum of real_loss and generated_loss." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "wkMNfBWlT-PV", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", - "\n", - " # [0,0,...,0] with generated images since they are fake\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", - "\n", - " total_loss = real_loss + generated_loss\n", - "\n", - " return total_loss" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "MgIc7i0th_Iu" - }, - "cell_type": "markdown", - "source": [ - "The discriminator and the generator optimizers are different since we will train two networks separately." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "iWCn_PVdEJZ7", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "mWtinsGDPJlV" - }, - "cell_type": "markdown", - "source": [ - "**Checkpoints (Object-based saving)**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "CA1w-7s2POEy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "cell_type": "markdown", - "source": [ - "## Set up GANs for Training\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "5QC5BABamh_c" - }, - "cell_type": "markdown", - "source": [ - "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagam at the beginning of the tutorial." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Ff6oN6PZX27n" - }, - "cell_type": "markdown", - "source": [ - "**Define training parameters**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "NS2GWywBbAWo", - "colab": {} - }, - "cell_type": "code", - "source": [ - "EPOCHS = 50\n", - "noise_dim = 100\n", - "num_examples_to_generate = 16\n", - "\n", - "# We'll re-use this random vector used to seed the generator so\n", - "# it will be easier to see the improvement over time.\n", - "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", - " noise_dim])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "jylSonrqSWfi" - }, - "cell_type": "markdown", - "source": [ - "**Define training method**\n", - "\n", - "We start by iterating over the dataset. The generator is given a random vector as an input which is processed to output an image looking like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", - "\n", - "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables." - ] - }, - { - "metadata": { - "id": "3t5ibNo05jCB", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train_step(images):\n", - " # generating noise from a normal distribution\n", - " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", - " \n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " generated_images = generator(noise, training=True)\n", - " \n", - " real_output = discriminator(images, training=True)\n", - " generated_output = discriminator(generated_images, training=True)\n", - " \n", - " gen_loss = generator_loss(generated_output)\n", - " disc_loss = discriminator_loss(real_output, generated_output)\n", - " \n", - " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", - " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", - " \n", - " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6TSZgwc2BUQ-" - }, - "cell_type": "markdown", - "source": [ - "\n", - "This model takes about ~30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", - "\n", - "Eager execution can be slower than executing the equivalent graph as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." - ] - }, - { - "metadata": { - "id": "Iwya07_j5p2A", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_step = tf.contrib.eager.defun(train_step)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "2M7LmLtGEMQJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " train_step(images)\n", - "\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epoch + 1,\n", - " random_vector_for_generation)\n", - " \n", - " # saving (checkpoint) the model every 15 epochs\n", - " if (epoch + 1) % 15 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - " \n", - " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", - " time.time()-start))\n", - " # generating after the final epoch\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epochs,\n", - " random_vector_for_generation)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "2aFF7Hk3XdeW" - }, - "cell_type": "markdown", - "source": [ - "**Generate and save images**\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "RmdVsmvhPxyy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " # make sure the training parameter is set to False because we\n", - " # don't want to train the batchnorm layer when doing inference.\n", - " predictions = model(test_input, training=False)\n", - "\n", - " fig = plt.figure(figsize=(4,4))\n", - " \n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", - " plt.axis('off')\n", - " \n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "dZrd4CdjR-Fp" - }, - "cell_type": "markdown", - "source": [ - "## Train the GANs\n", - "We will call the train() method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", - "\n", - "At the beginning of the training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "Ly3UN0SLLY2l", - "colab": {} - }, - "cell_type": "code", - "source": [ - "%%time\n", - "train(train_dataset, EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "rfM4YcPVPkNO" - }, - "cell_type": "markdown", - "source": [ - "**Restore the latest checkpoint**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "XhXsd0srPo8c", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "cell_type": "markdown", - "source": [ - "## Generated images \n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "mLskt7EfXAjr" - }, - "cell_type": "markdown", - "source": [ - "\n", - "After training, its time to generate some images! \n", - "The last step is to plot the generated images and voila!\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "WfO5wCdclHGL", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Display a single image using the epoch number\n", - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "5x3q9_Oe5q0A", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display_image(EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "cell_type": "markdown", - "source": [ - "**Generate a GIF of all the saved images**\n", - "\n", - "We will use imageio to create an animated gif using all the images saved during training." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "IGKQgENQ8lEI", - "colab": {} - }, - "cell_type": "code", - "source": [ - "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) > round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp dcgan.gif dcgan.gif.png')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "cGhC3-fMWSwl" - }, - "cell_type": "markdown", - "source": [ - "Display the animated gif with all the mages generated during the training of GANs." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "uV0yiKpzNP1b", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display.Image(filename=\"dcgan.gif.png\")" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6EEG-wePkmJQ" - }, - "cell_type": "markdown", - "source": [ - "**Download the animated gif**\n", - "\n", - "Uncomment the code below to download an animated gif from Colab." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "4UJjSnIMOzOJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "#from google.colab import files\n", - "#files.download('dcgan.gif')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "k6qC-SbjK0yW" - }, - "cell_type": "markdown", - "source": [ - "## Learn more about GANs\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "xjjkT9KAK6H7" - }, - "cell_type": "markdown", - "source": [ - "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", - "\n", - "To learn more about GANs:\n", - "\n", - "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture form Stanford's CS231n. \n", - "\n", - "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" - ] + "kernelspec": { + "display_name": "Python 2", + "name": "python2" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png deleted file mode 100644 index b715bd83ef117641c6429e0ac173dbe9b8d5fd88..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png and /dev/null differ diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 12c5eff2b4aa901bdab52bf545e95b1e4dce7468..979772acd3f823a8cc53ab5e026946ad3bb19353 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1174 +1,71 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "K2s1A9eLRPEj" - }, - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Cffg2i257iMS" - }, - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "QASbY_HGo4Lq" - }, - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "U8l4RJ0XRPEm" - }, - "outputs": [], - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "b6qbGw8MRPE5" - }, - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "krQuPYTtRPE7" - }, - "outputs": [], - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aANEzb5WwSzg" - }, - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4G3b8x8_RPFD" - }, - "outputs": [], - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "mPBMgK34RPFL" - }, - "outputs": [], - "source": [ - "len(train_captions), len(all_captions)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "8cSW4u-ORPFQ" - }, - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "zXR0217aRPFR" - }, - "outputs": [], - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MDvIu4sXRPFV" - }, - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RD3vW4SsRPFW" - }, - "outputs": [], - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "rERqlR3WRPGO" - }, - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Dx_fvbVgRPGQ" - }, - "outputs": [], - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nyqH3zFwRPFi" - }, - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "HZfK8RhQRPFj" - }, - "outputs": [], - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "oJGE34aiRPFo" - }, - "outputs": [], - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "8Q44tNQVRPFt" - }, - "outputs": [], - "source": [ - "tokenizer.word_index[''] = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "0fpJb5ojRPFv" - }, - "outputs": [], - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AidglIZVRPF4" - }, - "outputs": [], - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "gL0wkttkRPGA" - }, - "outputs": [], - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "M3CD75nDpvTI" - }, - "source": [ - "## Split the data into training and testing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iS7DDMszRPGF" - }, - "outputs": [], - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "XmViPkRFRPGH" - }, - "outputs": [], - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uEWM9xrYcg45" - }, - "source": [ - "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Q3TnZ1ToRPGV" - }, - "outputs": [], - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "SmZS2N0bXG3T" - }, - "outputs": [], - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "FDF_Nm3tRPGZ" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "nrvoDphgRPGd" - }, - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AAppCGLKRPGd" - }, - "outputs": [], - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ja2LFTMSdeV3" - }, - "outputs": [], - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "AZ7R1RxHRPGf" - }, - "outputs": [], - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "V9UbGQmERPGi" - }, - "outputs": [], - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Qs_Sr03wRPGk" - }, - "outputs": [], - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-bYN7xA0RPGl" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PHod7t72RPGn" - }, - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Vt4WZ5mhJE-E" - }, - "outputs": [], - "source": [ - "# adding this in a separate cell because if you run the training cell \n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "UlA4VIQpRPGo" - }, - "outputs": [], - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "1Wm83G-ZBPcC" - }, - "outputs": [], - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xGvOcLQKghXN" - }, - "source": [ - "## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RCWpDtyNRPGs" - }, - "outputs": [], - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(tokenizer.index_word[predicted_id])\n", - "\n", - " if tokenizer.index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "fD_y7PD6RPGt" - }, - "outputs": [], - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] }, - "colab_type": "code", - "id": "io7ws3ReRPGv" - }, - "outputs": [], - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rprk3HEvZuxb" - }, - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "This example has moved:\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] }, - "colab_type": "code", - "id": "9Psd1quzaAWg" - }, - "outputs": [], - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "VJZXyJco6uLO" - }, - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "image_captioning_with_attention.ipynb", - "private_outputs": true, - "provenance": [ { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" } - ], - "toc_visible": true, - "version": "0.3.2", - "views": {} - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index bda9e77085e45ae31a228142135425e22a1c6780..c945c753b3ba36d16aa6985d23a5849f8f552304 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -13,633 +13,13 @@ "\n", "# Text Generation using a RNN\n", "\n", + "This example has moved.\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BwpJ5IffzRG6" - }, - "source": [ - "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", - " \n", - "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", - "\n", - "```\n", - "were to the death of him\n", - "And nothing of the field in the view of hell,\n", - "When I said, banish him, I will not burn thee that would live.\n", - "\n", - "HENRY BOLINGBROKE:\n", - "My gracious uncle--\n", - "\n", - "DUKE OF YORK:\n", - "As much disgraced to the court, the gods them speak,\n", - "And now in peace himself excuse thee in the world.\n", - "\n", - "HORTENSIO:\n", - "Madam, 'tis not the cause of the counterfeit of the earth,\n", - "And leave me to the sun that set them on the earth\n", - "And leave the world and are revenged for thee.\n", - "\n", - "GLOUCESTER:\n", - "I would they were talking with the very name of means\n", - "To make a puppet of a guest, and therefore, good Grumio,\n", - "Nor arm'd to prison, o' the clouds, of the whole field,\n", - "With the admire\n", - "With the feeding of thy chair, and we have heard it so,\n", - "I thank you, sir, he is a visor friendship with your silly your bed.\n", - "\n", - "SAMPSON:\n", - "I do desire to live, I pray: some stand of the minds, make thee remedies\n", - "With the enemies of my soul.\n", - "\n", - "MENENIUS:\n", - "I'll keep the cause of my mistress.\n", - "\n", - "POLIXENES:\n", - "My brother Marcius!\n", - "\n", - "Second Servant:\n", - "Will't ple\n", - "```\n", - "\n", - "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", - "\n", - "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", - "\n", - "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", - "\n", - "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "R3p22DBDsaCA" - }, - "source": [ - "## Install unidecode library\n", - "A helpful library to convert unicode to ASCII." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wZ6LOM12wKGH" - }, - "outputs": [], - "source": [ - "!pip install unidecode" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WGyKZj3bzf9p" - }, - "source": [ - "## Import tensorflow and enable eager execution." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "yG_n40gFzf9s" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "\n", - "# Note: Once you enable eager execution, it cannot be disabled. \n", - "tf.enable_eager_execution()\n", - "\n", - "import numpy as np\n", - "import os\n", - "import re\n", - "import random\n", - "import unidecode\n", - "import time" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EHDoRoc5PKWz" - }, - "source": [ - "## Download the dataset\n", - "\n", - "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pD_55cOxLkAb" - }, - "outputs": [], - "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UHjdCjDuSvX_" - }, - "source": [ - "## Read the dataset\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-E5JvY3wzf94" - }, - "outputs": [], - "source": [ - "text = unidecode.unidecode(open(path_to_file).read())\n", - "# length of text is the number of characters in it\n", - "print (len(text))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Il9ww98izf-D" - }, - "source": [ - "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "IalZLbvOzf-F" - }, - "outputs": [], - "source": [ - "# unique contains all the unique characters in the file\n", - "unique = sorted(set(text))\n", - "\n", - "# creating a mapping from unique characters to indices\n", - "char2idx = {u:i for i, u in enumerate(unique)}\n", - "idx2char = {i:u for i, u in enumerate(unique)}" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1v_qUYfAzf-I" - }, - "outputs": [], - "source": [ - "# setting the maximum length sentence we want for a single input in characters\n", - "max_length = 100\n", - "\n", - "# length of the vocabulary in chars\n", - "vocab_size = len(unique)\n", - "\n", - "# the embedding dimension \n", - "embedding_dim = 256\n", - "\n", - "# number of RNN (here GRU) units\n", - "units = 1024\n", - "\n", - "# batch size \n", - "BATCH_SIZE = 64\n", - "\n", - "# buffer size to shuffle our dataset\n", - "BUFFER_SIZE = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "LFjSVAlWzf-N" - }, - "source": [ - "## Creating the input and output tensors\n", - "\n", - "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", - "\n", - "But first, we need to create the input and output vectors.\n", - "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", - "\n", - "For example, consider that the string = 'tensorflow' and the max_length is 9\n", - "\n", - "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", - "\n", - "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "0UHJDA39zf-O" - }, - "outputs": [], - "source": [ - "input_text = []\n", - "target_text = []\n", - "\n", - "for f in range(0, len(text)-max_length, max_length):\n", - " inps = text[f:f+max_length]\n", - " targ = text[f+1:f+1+max_length]\n", - "\n", - " input_text.append([char2idx[i] for i in inps])\n", - " target_text.append([char2idx[t] for t in targ])\n", - " \n", - "print (np.array(input_text).shape)\n", - "print (np.array(target_text).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MJdfPmdqzf-R" - }, - "source": [ - "## Creating batches and shuffling them using tf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "p2pGotuNzf-S" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "m8gPwEjRzf-Z" - }, - "source": [ - "## Creating the model\n", - "\n", - "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. We use 3 layers to define our model.\n", - "\n", - "* Embedding layer\n", - "* GRU layer (you can use an LSTM layer here)\n", - "* Fully connected layer" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "P3KTiiInzf-a" - }, - "outputs": [], - "source": [ - "class Model(tf.keras.Model):\n", - " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", - " super(Model, self).__init__()\n", - " self.units = units\n", - " self.batch_sz = batch_size\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - "\n", - " if tf.test.is_gpu_available():\n", - " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " self.gru = tf.keras.layers.GRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')\n", - "\n", - " self.fc = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " def call(self, x, hidden):\n", - " x = self.embedding(x)\n", - "\n", - " # output shape == (batch_size, max_length, hidden_size) \n", - " # states shape == (batch_size, hidden_size)\n", - "\n", - " # states variable to preserve the state of the model\n", - " # this will be used to pass at every step to the model while training\n", - " output, states = self.gru(x, initial_state=hidden)\n", - "\n", - "\n", - " # reshaping the output so that we can pass it to the Dense layer\n", - " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", - " output = tf.reshape(output, (-1, output.shape[2]))\n", - "\n", - " # The dense layer will output predictions for every time_steps(max_length)\n", - " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", - " x = self.fc(output)\n", - "\n", - " return x, states" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "trpqTWyvk0nr" - }, - "source": [ - "## Call the model and set the optimizer and the loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "7t2XrzEOzf-e" - }, - "outputs": [], - "source": [ - "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dkjWIATszf-h" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", - "def loss_function(real, preds):\n", - " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3K6s6F79P7za" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oAGisDdfP9rL" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", - " model=model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lPrP0XMUzf-p" - }, - "source": [ - "## Train the model\n", - "\n", - "Here we will use a custom training loop with the help of GradientTape()\n", - "\n", - "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). We do this by calling the function defined while creating the model.\n", - "\n", - "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", - "\n", - "* There are a lot of interesting things happening here.\n", - " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", - " * The model then returns the predictions **P1** and **H1**.\n", - " * For the next batch of input, the model receives **I1** and **H1**.\n", - " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", - " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", - "\n", - "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", - "\n", - "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", - "\n", - "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "d4tSNwymzf-q" - }, - "outputs": [], - "source": [ - "# Training step\n", - "\n", - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " \n", - " # initializing the hidden state at the start of every epoch\n", - " hidden = model.reset_states()\n", - " \n", - " for (batch, (inp, target)) in enumerate(dataset):\n", - " with tf.GradientTape() as tape:\n", - " # feeding the hidden state back into the model\n", - " # This is the interesting step\n", - " predictions, hidden = model(inp, hidden)\n", - " \n", - " # reshaping the target because that's how the \n", - " # loss function expects it\n", - " target = tf.reshape(target, (-1,))\n", - " loss = loss_function(target, predictions)\n", - " \n", - " grads = tape.gradient(loss, model.variables)\n", - " optimizer.apply_gradients(zip(grads, model.variables))\n", - "\n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", - " batch,\n", - " loss))\n", - " # saving (checkpoint) the model every 5 epochs\n", - " if (epoch + 1) % 5 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", - " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "01AR9vpNQMFF" - }, - "source": [ - "## Restore the latest checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyvpYomYQQkF" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DjGz1tDkzf-u" - }, - "source": [ - "## Predicting using our trained model\n", - "\n", - "The below code block is used to generated the text\n", - "\n", - "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", - "\n", - "* We get predictions using the start_string and the hidden state\n", - "\n", - "* Then we use argmax to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", - "\n", - "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", - "\n", - "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WvuwZBX5Ogfd" - }, - "outputs": [], - "source": [ - "# Evaluation step(generating text using the model learned)\n", - "\n", - "# number of characters to generate\n", - "num_generate = 1000\n", - "\n", - "# You can change the start string to experiment\n", - "start_string = 'Q'\n", - "# converting our start string to numbers(vectorizing!) \n", - "input_eval = [char2idx[s] for s in start_string]\n", - "input_eval = tf.expand_dims(input_eval, 0)\n", - "\n", - "# empty string to store our results\n", - "text_generated = ''\n", - "\n", - "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", - "hidden = [tf.zeros((1, units))]\n", - "for i in range(num_generate):\n", - " predictions, hidden = model(input_eval, hidden)\n", - "\n", - " # using argmax to predict the word returned by the model\n", - " predicted_id = tf.argmax(predictions[-1]).numpy()\n", - " \n", - " # We pass the predicted word as the next input to the model\n", - " # along with the previous hidden state\n", - " input_eval = tf.expand_dims([predicted_id], 0)\n", - " \n", - " text_generated += idx2char[predicted_id]\n", - "\n", - "print (start_string + text_generated)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AM2Uma_-yVIq" - }, - "source": [ - "## Next steps\n", - "\n", - "* Change the start string to a different character, or the start of a sentence.\n", - "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Add another RNN layer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gtEd86sX5cB2" - }, - "outputs": [], - "source": [ - "" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD index 7bdf9053de749af9d09b12ba7b848e21c1fdb8f0..35d509904211d98f124d2555fc48166e75cb0dd9 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -28,7 +28,7 @@ py_library( cuda_py_test( name = "l2hmc_test", - size = "large", + size = "medium", srcs = ["l2hmc_test.py"], additional_deps = [ ":l2hmc", @@ -36,4 +36,8 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//third_party/py/numpy", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 74ce9e84f013d79b3a33ffa79993980b561e366d..30afef83bc5c6c164c8456ed472f4d6064068a25 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -9,6 +9,13 @@ py_binary( name = "linear_regression", srcs = ["linear_regression.py"], srcs_version = "PY2AND3", + deps = [":linear_regression_lib"], +) + +py_library( + name = "linear_regression_lib", + srcs = ["linear_regression.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -20,10 +27,13 @@ cuda_py_test( size = "small", srcs = ["linear_regression_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "oss_serial", + ], ) cuda_py_test( @@ -31,7 +41,7 @@ cuda_py_test( size = "small", srcs = ["linear_regression_graph_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 099b712fc06d1d3eb9ab4095f8db7283690bda76..206ef9409df7b1dc21de42ba919d2ba97f334a8c 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -56,7 +56,7 @@ class LinearModel(tf.keras.Model): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) + return tf.reduce_mean(tf.squared_difference(model(xs), ys)) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 66d52a74943d0d81fde05ce51b019558b327978d..436e887736158ec1ba8e46eac8de4ac7b8e6be01 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -1,11 +1,28 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "accelerator": "GPU" + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AOpGoE2T-YXS" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -13,19 +30,19 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CiwtNgENbx2g" }, + "cell_type": "markdown", "source": [ "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", "\n", @@ -33,24 +50,22 @@ "\n", "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", "\n", - "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", + "\"spanish-english\n", "\n", "Note: This example takes approximately 10 mintues to run on a single P100 GPU." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "tnxXKDjq3jEL" + "id": "tnxXKDjq3jEL", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "# Import TensorFlow >= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()\n", @@ -65,14 +80,16 @@ "import time\n", "\n", "print(tf.__version__)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wfodePkj3jEa" }, + "cell_type": "markdown", "source": [ "## Download and prepare the dataset\n", "\n", @@ -91,14 +108,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kRVATYOgJs1b" + "id": "kRVATYOgJs1b", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Download the file\n", "path_to_zip = tf.keras.utils.get_file(\n", @@ -106,17 +121,17 @@ " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "rd0jw-eC3jEh" + "id": "rd0jw-eC3jEh", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Converts the unicode file to ascii\n", "def unicode_to_ascii(s):\n", @@ -128,7 +143,7 @@ " w = unicode_to_ascii(w.lower().strip())\n", " \n", " # creating a space between a word and the punctuation following it\n", - " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", @@ -140,19 +155,19 @@ " \n", " # adding a start and an end token to the sentence\n", " # so that the model know when to start and stop predicting.\n", - " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", + " w = ' ' + w + ' '\n", " return w" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "OHn4Dct23jEm" + "id": "OHn4Dct23jEm", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# 1. Remove the accents\n", "# 2. Clean the sentences\n", @@ -163,20 +178,20 @@ " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", " \n", " return word_pairs" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "9xbqO7Iie9bb" + "id": "9xbqO7Iie9bb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", - "# (e.g., 5 -\u003e \"dad\") for each language,\n", + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", "class LanguageIndex():\n", " def __init__(self, lang):\n", " self.lang = lang\n", @@ -192,23 +207,23 @@ " \n", " self.vocab = sorted(self.vocab)\n", " \n", - " self.word2idx['\u003cpad\u003e'] = 0\n", + " self.word2idx[''] = 0\n", " for index, word in enumerate(self.vocab):\n", " self.word2idx[word] = index + 1\n", " \n", " for word, index in self.word2idx.items():\n", " self.idx2word[index] = word" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "eAY9k49G3jE_" + "id": "eAY9k49G3jE_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)\n", @@ -244,71 +259,71 @@ " padding='post')\n", " \n", " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GOi42V79Ydlr" }, + "cell_type": "markdown", "source": [ "### Limit the size of the dataset to experiment faster (optional)\n", "\n", - "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "cnxC7q-j3jFD" + "id": "cnxC7q-j3jFD", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Try experimenting with the size of that dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4QILQkOs3jFG" + "id": "4QILQkOs3jFG", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creating training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show length\n", "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rgCLkfv5uO3d" }, + "cell_type": "markdown", "source": [ "### Create a tf.data dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TqHsArVZ3jFS" + "id": "TqHsArVZ3jFS", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", @@ -320,27 +335,29 @@ "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TNfHIF71ulLu" }, + "cell_type": "markdown", "source": [ "## Write the encoder and decoder model\n", "\n", - "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", + "\"attention\n", "\n", "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n", "\n", "Here are the equations that are implemented:\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", + "\"attention\n", + "\"attention\n", "\n", "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", "\n", @@ -362,14 +379,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "avyJ_4VIUoHb" + "id": "avyJ_4VIUoHb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def gru(units):\n", " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", @@ -385,17 +400,17 @@ " return_state=True, \n", " recurrent_activation='sigmoid', \n", " recurrent_initializer='glorot_uniform')" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "nZ2rI24i3jFg" + "id": "nZ2rI24i3jFg", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", @@ -412,17 +427,17 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "yJ_B3mhW3jFk" + "id": "yJ_B3mhW3jFk", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", @@ -476,41 +491,41 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.dec_units))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "P5UY8wko3jFp" + "id": "P5UY8wko3jFp", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_ch_71VbIRfK" }, + "cell_type": "markdown", "source": [ "## Define the optimizer and the loss function" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WmTHr5iV3jFr" + "id": "WmTHr5iV3jFr", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", @@ -519,41 +534,43 @@ " mask = 1 - np.equal(real, 0)\n", " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", " return tf.reduce_mean(loss_)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DMVWzzsfNl4e" }, + "cell_type": "markdown", "source": [ "## Checkpoints (Object-based saving)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "Zj8bXQTgNwrF" + "id": "Zj8bXQTgNwrF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", " encoder=encoder,\n", " decoder=decoder)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hpObfY22IddU" }, + "cell_type": "markdown", "source": [ "## Training\n", "\n", @@ -567,14 +584,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ddefjBMa3jF0" + "id": "ddefjBMa3jF0", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "EPOCHS = 10\n", "\n", @@ -592,7 +607,7 @@ " \n", " dec_hidden = enc_hidden\n", " \n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", " \n", " # Teacher forcing - feeding the target as the next input\n", " for t in range(1, targ.shape[1]):\n", @@ -625,14 +640,16 @@ " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mU3Ce8M6I3rz" }, + "cell_type": "markdown", "source": [ "## Translate\n", "\n", @@ -644,14 +661,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "EbQpyYs13jF_" + "id": "EbQpyYs13jF_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", @@ -668,7 +683,7 @@ " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", @@ -681,24 +696,24 @@ "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", - " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", + " if targ_lang.idx2word[predicted_id] == '':\n", " return result, sentence, attention_plot\n", " \n", " # the predicted ID is fed back into the model\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "s5hQWlbN3jGF" + "id": "s5hQWlbN3jGF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", @@ -712,17 +727,17 @@ " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " plt.show()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "sl9zUHzg3jGI" + "id": "sl9zUHzg3jGI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", @@ -732,91 +747,93 @@ " \n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "n250XbnjOaqP" }, + "cell_type": "markdown", "source": [ "## Restore the latest checkpoint and test" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "UJpT9D5_OgP6" + "id": "UJpT9D5_OgP6", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WrAM0FDomq3E" + "id": "WrAM0FDomq3E", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "zSx2iM36EZQZ" + "id": "zSx2iM36EZQZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "A3LLCx3ZE0Ls" + "id": "A3LLCx3ZE0Ls", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "DUQVLVqUE1YW" + "id": "DUQVLVqUE1YW", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# wrong translation\n", "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RTe5P5ioMJwN" }, + "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -824,31 +841,5 @@ "* Experiment with training on a larger dataset, or using more epochs\n" ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "nmt_with_attention.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", - "timestamp": 1527858391290 - }, - { - "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", - "timestamp": 1527776041613 - } - ], - "toc_visible": true, - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index f3135a9668fc0dc7faa93a5f119b53f3efd34c6e..f2851d97223e483da11120f1fe3f0a2f641dfb81 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -27,7 +27,7 @@ py_library( cuda_py_test( name = "resnet50_test", - size = "large", + size = "medium", srcs = ["resnet50_test.py"], additional_deps = [ ":resnet50", @@ -35,17 +35,19 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", # Fix b/118130911 "nomsan", # Fix b/118130911 "notsan", # Fix b/118130911 "optonly", + "oss_serial", ], ) cuda_py_test( name = "resnet50_graph_test", - size = "large", + size = "medium", srcs = ["resnet50_graph_test.py"], additional_deps = [ ":resnet50", @@ -53,10 +55,12 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 4f0d46b1bae3760a63b2abe871034bdedf258f07..cb207b8ddf3641a68a114386f6a95a26ce2b74d6 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -67,30 +67,36 @@ py_library( # Tests cuda_py_test( name = "ops_test", - size = "large", + size = "medium", srcs = ["ops_test.py"], additional_deps = [ ":ops", "//tensorflow:tensorflow_py", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) cuda_py_test( name = "blocks_test", - size = "large", + size = "medium", srcs = ["blocks_test.py"], additional_deps = [ ":blocks", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ + "no_oss", # b/123045964 "optonly", ], ) cuda_py_test( name = "revnet_test", - size = "large", + size = "medium", srcs = ["revnet_test.py"], additional_deps = [ ":blocks_test", @@ -98,9 +104,11 @@ cuda_py_test( ":revnet", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", # depends on blocks_test, which is not available in pip package "optonly", + "oss_serial", ], ) @@ -127,6 +135,13 @@ py_binary( name = "main", srcs = ["main.py"], srcs_version = "PY2AND3", + deps = [":main_lib"], +) + +py_library( + name = "main_lib", + srcs = ["main.py"], + srcs_version = "PY2AND3", deps = [ ":cifar_input", ":config", @@ -141,7 +156,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -153,7 +168,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -165,7 +180,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index 1f2cb14972f0b92d29489adff8f94e790e1ec4ed..7406787ba438345dc485c50e347e40597b2037f5 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -96,6 +96,7 @@ class RevNet(tf.keras.Model): def call(self, inputs, training=True): """Forward pass.""" + saved_hidden = None if training: saved_hidden = [inputs] diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index d500b632ebb97fd12ded3a215b0f1a686194874f..f4dbe7ac16f734f7bee045bc71e9559b630adf81 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_colorbot", srcs = ["rnn_colorbot.py"], srcs_version = "PY2AND3", + deps = [":rnn_colorbot_lib"], +) + +py_library( + name = "rnn_colorbot_lib", + srcs = ["rnn_colorbot.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -21,8 +28,11 @@ cuda_py_test( name = "rnn_colorbot_test", srcs = ["rnn_colorbot_test.py"], additional_deps = [ - ":rnn_colorbot", + ":rnn_colorbot_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 74ebb1ec77131a560b1ebfd062c690920c35e261..1c718a5ce3d8e1541656d92fd5e8dad6c6683c4c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -207,7 +207,7 @@ class RNNColorbot(tf.keras.Model): def loss(labels, predictions): """Computes mean squared loss.""" - return tf.reduce_mean(tf.square(predictions - labels)) + return tf.reduce_mean(tf.squared_difference(predictions, labels)) def test(model, eval_data): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index 2cc2fcbfeb21ee6218d7912d9a93ea2f7b2ea226..43a6ca526d3a0aecda2c8df865a0487ac28758ab 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_ptb", srcs = ["rnn_ptb.py"], srcs_version = "PY2AND3", + deps = [":rnn_ptb_lib"], +) + +py_library( + name = "rnn_ptb_lib", + srcs = ["rnn_ptb.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -21,18 +28,22 @@ cuda_py_test( name = "rnn_ptb_test", srcs = ["rnn_ptb_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = ["no_oss"], # b/123045964 ) cuda_py_test( name = "rnn_ptb_graph_test", srcs = ["rnn_ptb_graph_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 15776c694e92825895437a4c1547699f6d9269fb..9b5a2c947b153308c83f1a922d06c034ec5f9ddf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -128,7 +128,7 @@ class PTBModel(tf.keras.Model): self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) - self._output_shape = [-1, embedding_dim] + self._output_shape = [-1, hidden_dim] def call(self, input_seq, training): """Run the forward pass of PTBModel. diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 5966f1d4873e8e77b3ad5914da7bfc7e69d4e341..9b0fbaa6793e28d327745767e6ccd3085211ff7d 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -42,5 +42,6 @@ cuda_py_test( "no-internal-py3", # flaky "no_cuda_on_cpu_tap", "no_pip", # because spinn.py is under third_party/. + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 566246de4957c1dc5919c10e22146706f9e50be8..c8d9266672a8b87d32338ea7c4f74fb40d41c767 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -37,7 +37,7 @@ from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(checkpointable.CheckpointableBase): +class Metric(checkpointable.Checkpointable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 31481d7685c79b76c40b1f8041441a0e71d3b00e..b82e1bb71bce9a28d7bbbf961cc6d5e25dd18acf 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -138,7 +138,7 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable from tensorflow.python.training.checkpointable.util import CheckpointableSaver from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index e344d7a23b55134612aab430b50cf065bd1095e4..cb86efb8da72f168b54f04773289a6fe421282b1 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/ops/wals.py", ], dso = [ - ":python/ops/_clustering_ops.so", ":python/ops/_factorization_ops.so", ], kernels = [ @@ -38,12 +37,12 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":factorization_ops_test_utils_py", - ":gen_clustering_ops", ":gen_factorization_ops", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:clustering_ops_gen", "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:embedding_ops", @@ -77,17 +76,6 @@ py_library( ], ) -# Ops -tf_custom_op_library( - name = "python/ops/_clustering_ops.so", - srcs = [ - "ops/clustering_ops.cc", - ], - deps = [ - "//tensorflow/contrib/factorization/kernels:clustering_ops", - ], -) - tf_custom_op_library( name = "python/ops/_factorization_ops.so", srcs = [ @@ -100,26 +88,16 @@ tf_custom_op_library( ) tf_gen_op_libs([ - "clustering_ops", "factorization_ops", ]) cc_library( name = "all_ops", deps = [ - ":clustering_ops_op_lib", ":factorization_ops_op_lib", ], ) -tf_gen_op_wrapper_py( - name = "gen_clustering_ops", - out = "python/ops/gen_clustering_ops.py", - deps = [ - ":clustering_ops_op_lib", - ], -) - tf_gen_op_wrapper_py( name = "gen_factorization_ops", out = "python/ops/gen_factorization_ops.py", diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index ea8b9a17a27093cb57564861815edd6ecb18a014..23d7e088d067effa446e4bcdc9609db612066568 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -11,7 +11,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( name = "all_kernels", deps = [ - ":clustering_ops", ":masked_matmul_ops", ":wals_solver_ops", "@protobuf_archive//:protobuf_headers", @@ -29,17 +28,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "clustering_ops", - srcs = ["clustering_ops.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - cc_library( name = "masked_matmul_ops", srcs = ["masked_matmul_ops.cc"], @@ -51,19 +39,3 @@ cc_library( ], alwayslink = 1, ) - -tf_cc_test( - name = "clustering_ops_test", - srcs = ["clustering_ops_test.cc"], - deps = [ - ":clustering_ops", - "//tensorflow/contrib/factorization:clustering_ops_op_lib", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index a8c5d0763c28ba2b54f217405f0da65533f26b91..68078ba8bbb07b4344c19d554012d214229f9c4f 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -19,12 +19,12 @@ #include #include +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc deleted file mode 100644 index 2686702c1d5768f661dac610c96089eb02e360d7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2016 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy -// of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. -// ============================================================================== - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("KmeansPlusPlusInitialization") - .Input("points: float32") - .Input("num_to_sample: int64") - .Input("seed: int64") - .Input("num_retries_per_sample: int64") - .Output("samples: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects num_to_sample rows of input using the KMeans++ criterion. - -Rows of points are assumed to be input points. One row is selected at random. -Subsequent rows are sampled with probability proportional to the squared L2 -distance from the nearest row selected thus far till num_to_sample rows have -been sampled. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -num_to_sample: Scalar. The number of rows to sample. This value must not be - larger than n. -seed: Scalar. Seed for initializing the random number generator. -num_retries_per_sample: Scalar. For each row that is sampled, this parameter - specifies the number of additional points to draw from the current - distribution before selecting the best. If a negative value is specified, a - heuristic is used to sample O(log(num_to_sample)) additional points. -samples: Matrix of shape (num_to_sample, d). The sampled rows. -)"); - -REGISTER_OP("KMC2ChainInitialization") - .Input("distances: float32") - .Input("seed: int64") - .Output("index: int64") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"( -Returns the index of a data point that should be added to the seed set. - -Entries in distances are assumed to be squared distances of candidate points to -the already sampled centers in the seed set. The op constructs one Markov chain -of the k-MC^2 algorithm and returns the index of one candidate point to be added -as an additional cluster center. - -distances: Vector with squared distances to the closest previously sampled - cluster center for each candidate point. -seed: Scalar. Seed for initializing the random number generator. -index: Scalar with the index of the sampled point. -)"); - -REGISTER_OP("NearestNeighbors") - .Input("points: float32") - .Input("centers: float32") - .Input("k: int64") - .Output("nearest_center_indices: int64") - .Output("nearest_center_distances: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects the k nearest centers for each point. - -Rows of points are assumed to be input points. Rows of centers are assumed to be -the list of candidate centers. For each point, the k centers that have least L2 -distance to it are computed. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -centers: Matrix of shape (m, d). Rows are assumed to be centers. -k: Scalar. Number of nearest centers to return for each point. If k is larger - than m, then only m centers are returned. -nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the - indices of the centers closest to the corresponding point, ordered by - increasing distance. -nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the - squared L2 distance to the corresponding center in nearest_center_indices. -)"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 84e80791f4991ad2b67d0a00ee1e00cf0d0daadc..d48b89cbacce34781819010addbcbd0ba66f9873 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -18,28 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.factorization.python.ops import gen_clustering_ops -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import * -# pylint: enable=wildcard-import -from tensorflow.contrib.util import loader from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_clustering_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.platform import resource_loader - -_clustering_ops = loader.load_op_library( - resource_loader.get_path_to_datafile('_clustering_ops.so')) +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_clustering_ops import * +# pylint: enable=wildcard-import # Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) # which is the square root of the sum of the absolute squares of the elements diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index d365ad111760247fc18b730657390f07ba6b865e..9f0664dfe5ba7a098b6976388d1cf737dafb4842 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -314,8 +314,7 @@ class GmmAlgorithm(object): # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( math_ops.log(self._covs + 1e-3), 1, keepdims=True) - diff = shard - self._means - x2 = math_ops.square(diff) + x2 = math_ops.squared_difference(shard, self._means) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) # num_classes X num_examples x2_cov = math_ops.matmul(x2, cov_expanded) diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 4c1d1a29f20b5574b63cf87ecf62db95f92902cd..8fc5f1cfe7800653ef1e43c6d40d1a66e34f2106 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -6,7 +6,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "feature_column_py", @@ -37,13 +37,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_test", srcs = ["python/feature_column/sequence_feature_column_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -53,17 +53,14 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], + tags = ["no_pip"], ) -py_test( +tf_py_test( name = "sequence_feature_column_integration_test", srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -73,6 +70,7 @@ py_test( "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], + tags = ["no_pip"], ) py_library( @@ -94,14 +92,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_v2_test", srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sequence_feature_column", + additional_deps = [ ":sequence_feature_column_v2", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -112,7 +109,23 @@ py_test( "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/feature_column:feature_column_v2_test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", + ], + tags = ["no_pip"], +) + +py_test( + name = "sequence_feature_column_v2_integration_test", + srcs = ["python/feature_column/sequence_feature_column_v2_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/keras:layers", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 83b93ec332044f754f9dcde8d7c5c19b26e53a4a..2f4bda194a41242167e0abfcaeac5044f6026f85 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -27,6 +27,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc +from tensorflow.python.feature_column import feature_column_v2 as fc_v2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -34,107 +35,115 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import variable_scope # pylint: disable=protected-access -def sequence_input_layer( - features, - feature_columns, - weight_collections=None, - trainable=True): - """"Builds input layer for sequence input. +class SequenceFeatures(fc_v2._BaseFeaturesLayer): + """A layer for sequence input. - All `feature_columns` must be sequence dense columns with the same - `sequence_length`. The output of this method can be fed into sequence - networks, such as RNN. + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. - The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ from - batch to batch. + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. - If multiple `feature_columns` are given with `Di` `num_elements` each, their - outputs are concatenated. So, the final `Tensor` has shape - `[batch_size, T, D0 + D1 + ... + Dn]`. + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. - Example: + Example: - ```python - rating = sequence_numeric_column('rating') - watches = sequence_categorical_column_with_identity( - 'watches', num_buckets=1000) - watches_embedding = embedding_column(watches, dimension=10) - columns = [rating, watches] + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + sequence_input_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_input_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) - ``` + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) + ``` + """ - Args: - features: A dict mapping keys to tensors. - feature_columns: An iterable of dense sequence columns. Valid columns are - - `embedding_column` that wraps a `sequence_categorical_column_with_*` - - `sequence_numeric_column`. - weight_collections: A list of collection names to which the Variable will be - added. Note that variables will also be added to collections - `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. - trainable: If `True` also add the variable to the graph collection - `GraphKeys.TRAINABLE_VARIABLES`. + def __init__( + self, + feature_columns, + trainable=True, + name=None, + **kwargs): + """"Constructs a SequenceFeatures layer. - Returns: - An `(input_layer, sequence_length)` tuple where: - - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ - from batch to batch. `D` is the sum of `num_elements` for all - `feature_columns`. - - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence - length for each example. + Args: + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + trainable: Boolean, whether the layer's variables will be updated via + gradient descent during training. + name: Name to give to the SequenceFeatures. + **kwargs: Keyword arguments to construct a layer. + + Raises: + ValueError: If any of the `feature_columns` is not a + `SequenceDenseColumn`. + """ + super(SequenceFeatures, self).__init__( + feature_columns=feature_columns, + trainable=trainable, + name=name, + expected_column_type=fc_v2.SequenceDenseColumn, + **kwargs) - Raises: - ValueError: If any of the `feature_columns` is the wrong type. - """ - feature_columns = fc_old._normalize_feature_columns(feature_columns) - for c in feature_columns: - if not isinstance(c, fc_old._SequenceDenseColumn): - raise ValueError( - 'All feature_columns must be of type _SequenceDenseColumn. ' - 'You can wrap a sequence_categorical_column with an embedding_column ' - 'or indicator_column. ' - 'Given (type {}): {}'.format(type(c), c)) - - with variable_scope.variable_scope( - None, default_name='sequence_input_layer', values=features.values()): - builder = fc_old._LazyBuilder(features) + def _target_shape(self, input_shape, total_elements): + return (input_shape[0], input_shape[1], total_elements) + + def call(self, features): + """Returns sequence input corresponding to the `feature_columns`. + + Args: + features: A dict mapping keys to tensors. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If features are not a dictionary. + """ + if not isinstance(features, dict): + raise ValueError('We expected a dictionary here. Instead we got: ', + features) + transformation_cache = fc.FeatureTransformationCache(features) output_tensors = [] sequence_lengths = [] - ordered_columns = [] - - for column in sorted(feature_columns, key=lambda x: x.name): - ordered_columns.append(column) - with variable_scope.variable_scope( - None, default_name=column._var_scope_name): - dense_tensor, sequence_length = column._get_sequence_dense_tensor( - builder, - weight_collections=weight_collections, - trainable=trainable) + + for column in self._feature_columns: + with ops.name_scope(column.name): + dense_tensor, sequence_length = column.get_sequence_dense_tensor( + transformation_cache, self._state_manager) # Flattens the final dimension to produce a 3D Tensor. - num_elements = column._variable_shape.num_elements() - shape = array_ops.shape(dense_tensor) - target_shape = [shape[0], shape[1], num_elements] - output_tensors.append( - array_ops.reshape(dense_tensor, shape=target_shape)) + output_tensors.append(self._process_dense_tensor(column, dense_tensor)) sequence_lengths.append(sequence_length) - fc_old._verify_static_batch_size_equality(output_tensors, ordered_columns) - fc_old._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + # Check and process sequence lengths. + fc_v2._verify_static_batch_size_equality(sequence_lengths, + self._feature_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) - return array_ops.concat(output_tensors, -1), sequence_length + return self._verify_and_concat_tensors(output_tensors), sequence_length def concatenate_context_input(context_input, sequence_input): @@ -203,12 +212,13 @@ def sequence_categorical_column_with_identity( columns = [watches_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -250,12 +260,13 @@ def sequence_categorical_column_with_hash_bucket( columns = [tokens_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -296,12 +307,13 @@ def sequence_categorical_column_with_vocabulary_file( columns = [states_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -358,12 +370,13 @@ def sequence_categorical_column_with_vocabulary_list( columns = [colors_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -415,12 +428,13 @@ def sequence_numeric_column( columns = [temperature] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -445,7 +459,7 @@ def sequence_numeric_column( ValueError: if any dimension in shape is not a positive integer. ValueError: if `dtype` is not convertible to `tf.float32`. """ - shape = fc_old._check_shape(shape=shape, key=key) + shape = fc_v2._check_shape(shape=shape, key=key) if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) @@ -540,8 +554,10 @@ class SequenceNumericColumn( # For the 2D case, the raw values are grouped according to num_elements; # for the 3D case, the grouping happens in the third dimension, and # sequence length is not affected. - num_elements = (self.variable_shape.num_elements() - if sp_tensor.shape.ndims == 2 else 1) + if sp_tensor.shape.ndims == 2: + num_elements = self.variable_shape.num_elements() + else: + num_elements = 1 seq_length = fc_old._sequence_length_from_sparse_tensor( sp_tensor, num_elements=num_elements) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b165a620ae67e855400eb297ec17db80eac7937 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py @@ -0,0 +1,283 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration test for sequence feature columns with SequenceExamples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string +import tempfile + +from google.protobuf import text_format + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class SequenceFeatureColumnIntegrationTest(test.TestCase): + + def _make_sequence_example(self): + example = example_pb2.SequenceExample() + example.context.feature['int_ctx'].int64_list.value.extend([5]) + example.context.feature['float_ctx'].float_list.value.extend([123.6]) + for val in range(0, 10, 2): + feat = feature_pb2.Feature() + feat.int64_list.value.extend([val] * val) + example.feature_lists.feature_list['int_list'].feature.extend([feat]) + for val in range(1, 11, 2): + feat = feature_pb2.Feature() + feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val) + example.feature_lists.feature_list['str_list'].feature.extend([feat]) + + return example + + def _build_feature_columns(self): + col = fc.categorical_column_with_identity('int_ctx', num_buckets=100) + ctx_cols = [ + fc.embedding_column(col, dimension=10), + fc.numeric_column('float_ctx') + ] + + identity_col = sfc.sequence_categorical_column_with_identity( + 'int_list', num_buckets=10) + bucket_col = sfc.sequence_categorical_column_with_hash_bucket( + 'bytes_list', hash_bucket_size=100) + seq_cols = [ + fc.embedding_column(identity_col, dimension=10), + fc.embedding_column(bucket_col, dimension=20) + ] + + return ctx_cols, seq_cols + + def test_sequence_example_into_input_layer(self): + examples = [_make_sequence_example().SerializeToString()] * 100 + ctx_cols, seq_cols = self._build_feature_columns() + + def _parse_example(example): + ctx, seq = parsing_ops.parse_single_sequence_example( + example, + context_features=fc.make_parse_example_spec_v2(ctx_cols), + sequence_features=fc.make_parse_example_spec_v2(seq_cols)) + ctx.update(seq) + return ctx + + ds = dataset_ops.Dataset.from_tensor_slices(examples) + ds = ds.map(_parse_example) + ds = ds.batch(20) + + # Test on a single batch + features = ds.make_one_shot_iterator().get_next() + + # Tile the context features across the sequence features + sequence_input_layer = sfc.SequenceFeatures(seq_cols) + seq_layer, _ = sequence_input_layer(features) + input_layer = fc.DenseFeatures(ctx_cols) + ctx_layer = input_layer(features) + input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + + rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) + output = rnn_layer(input_layer) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + features_r = sess.run(features) + self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6]) + + output_r = sess.run(output) + self.assertAllEqual(output_r.shape, [20, 10]) + + +class SequenceExampleParsingTest(test.TestCase): + + def test_seq_ex_in_sequence_categorical_column_with_identity(self): + self._test_parsed_sequence_example( + 'int_list', sfc.sequence_categorical_column_with_identity, + 10, [3, 6], [2, 4, 6]) + + def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket, + 10, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list, + list(string.ascii_lowercase), [3, 4], + [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self): + _, fname = tempfile.mkstemp() + with open(fname, 'w') as f: + f.write(string.ascii_lowercase) + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file, + fname, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def _test_parsed_sequence_example( + self, col_name, col_fn, col_arg, shape, values): + """Helper function to check that each FeatureColumn parses correctly. + + Args: + col_name: string, name to give to the feature column. Should match + the name that the column will parse out of the features dict. + col_fn: function used to create the feature column. For example, + sequence_numeric_column. + col_arg: second arg that the target feature column is expecting. + shape: the expected dense_shape of the feature after parsing into + a SparseTensor. + values: the expected values at index [0, 2, 6] of the feature + after parsing into a SparseTensor. + """ + example = _make_sequence_example() + columns = [ + fc.categorical_column_with_identity('int_ctx', num_buckets=100), + fc.numeric_column('float_ctx'), + col_fn(col_name, col_arg) + ] + context, seq_features = parsing_ops.parse_single_sequence_example( + example.SerializeToString(), + context_features=fc.make_parse_example_spec_v2(columns[:2]), + sequence_features=fc.make_parse_example_spec_v2(columns[2:])) + + with self.cached_session() as sess: + ctx_result, seq_result = sess.run([context, seq_features]) + self.assertEqual(list(seq_result[col_name].dense_shape), shape) + self.assertEqual( + list(seq_result[col_name].values[[0, 2, 6]]), values) + self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1]) + self.assertEqual(ctx_result['int_ctx'].values[0], 5) + self.assertEqual(list(ctx_result['float_ctx'].shape), [1]) + self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1) + + +_SEQ_EX_PROTO = """ +context { + feature { + key: "float_ctx" + value { + float_list { + value: 123.6 + } + } + } + feature { + key: "int_ctx" + value { + int64_list { + value: 5 + } + } + } +} +feature_lists { + feature_list { + key: "bytes_list" + value { + feature { + bytes_list { + value: "a" + } + } + feature { + bytes_list { + value: "b" + value: "c" + } + } + feature { + bytes_list { + value: "d" + value: "e" + value: "f" + value: "g" + } + } + } + } + feature_list { + key: "float_list" + value { + feature { + float_list { + value: 1.0 + } + } + feature { + float_list { + value: 3.0 + value: 3.0 + value: 3.0 + } + } + feature { + float_list { + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + } + } + } + } + feature_list { + key: "int_list" + value { + feature { + int64_list { + value: 2 + value: 2 + } + } + feature { + int64_list { + value: 4 + value: 4 + value: 4 + value: 4 + } + } + feature { + int64_list { + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + } + } + } + } +} +""" + + +def _make_sequence_example(): + example = example_pb2.SequenceExample() + return text_format.Parse(_SEQ_EX_PROTO, example) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index be012a87690c24c6d9b7808790393e1aa6d01211..a1feaddcc00d5fac86dca3138dfa1c6314bb6a8b 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -22,9 +22,7 @@ import os from absl.testing import parameterized import numpy as np -from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc -from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager from tensorflow.python.framework import dtypes @@ -32,13 +30,15 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.training import monitored_session -class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): +class SequenceFeaturesTest(test.TestCase, parameterized.TestCase): @parameterized.named_parameters( {'testcase_name': '2D', @@ -111,29 +111,27 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc_old._embedding_column( + embedding_column_b = fc.embedding_column( categorical_column_b, dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - # Test that columns are reordered alphabetically. - feature_columns=[embedding_column_b, embedding_column_a]) + # Test that columns are reordered alphabetically. + sequence_input_layer = sfc.SequenceFeatures( + [embedding_column_b, embedding_column_a]) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b,}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertCountEqual( - ('sequence_input_layer/aaa_embedding/embedding_weights:0', - 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + ('sequence_features/aaa_embedding/embedding_weights:0', + 'sequence_features/bbb_embedding/embedding_weights:0'), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) @@ -152,18 +150,17 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_embedding\. categorical_column must be of ' - r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[embedding_column_a]) + r'type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures([embedding_column_a]) + _, _ = sequence_input_layer({'aaa': sparse_input}) def test_shared_embedding_column(self): vocabulary_size = 3 @@ -210,21 +207,18 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - feature_columns=shared_embedding_columns) + sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertCountEqual( - ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + ('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -248,23 +242,20 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc_old._categorical_column_with_identity( + categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_shared_embedding\. categorical_column must ' - r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b - }, - feature_columns=shared_embedding_columns) + r'be of type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns) + _, _ = sequence_input_layer({'aaa': sparse_input_a, + 'bbb': sparse_input_b}) @parameterized.named_parameters( {'testcase_name': '2D', @@ -319,17 +310,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size_b) - indicator_column_b = fc_old._indicator_column(categorical_column_b) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - # Test that columns are reordered alphabetically. - feature_columns=[indicator_column_b, indicator_column_a]) + indicator_column_b = fc.indicator_column(categorical_column_b) + # Test that columns are reordered alphabetically. + sequence_input_layer = sfc.SequenceFeatures( + [indicator_column_b, indicator_column_a]) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -346,17 +335,16 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, r'In indicator_column: aaa_indicator\. categorical_column must be of ' - r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[indicator_column_a]) + r'type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures([indicator_column_a]) + _, _ = sequence_input_layer({'aaa': sparse_input}) @parameterized.named_parameters( {'testcase_name': '2D', @@ -375,7 +363,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): # feature 0, ids [[20, 3], [5]] # feature 1, ids [[3], [8]] 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), - 'values': (20, 3, 5., 3., 8.), + 'values': (20., 3., 5., 3., 8.), 'dense_shape': (2, 2, 2)}, 'expected_input_layer': [ [[20.], [3.], [5.], [0.]], @@ -386,11 +374,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): self, sparse_input_args, expected_input_layer, expected_sequence_length): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa') + numeric_column = sfc.sequence_numeric_column('aaa') - input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -428,14 +415,13 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): ) def test_numeric_column_multi_dim( self, sparse_input_args, expected_input_layer, expected_sequence_length): - """Tests sequence_input_layer for multi-dimensional numeric_column.""" + """Tests SequenceFeatures for multi-dimensional numeric_column.""" sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -454,22 +440,20 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): indices=((0, 0), (1, 0)), values=(1., 10.), dense_shape=(2, 2)) - numeric_column_a = sfc_old.sequence_numeric_column('aaa') - numeric_column_b = sfc_old.sequence_numeric_column('bbb') + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') - _, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - feature_columns=[numeric_column_a, numeric_column_b]) + sequence_input_layer = sfc.SequenceFeatures( + [numeric_column_a, numeric_column_b]) + _, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) with monitored_session.MonitoredSession() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[Condition x == y did not hold element-wise:\] ' - r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' - r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + r'\[x \(sequence_features/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_features/bbb/sequence_length:0\) = \] \[1 1\]'): sess.run(sequence_length) @parameterized.named_parameters( @@ -497,11 +481,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): self, sparse_input_args, expected_shape): """Tests that we return a known static shape when we have one.""" sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - input_layer, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, _ = sequence_input_layer({'aaa': sparse_input}) shape = input_layer.get_shape() self.assertEqual(shape, expected_shape) @@ -534,13 +517,49 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - input_layer, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, feature_columns=[indicator_column]) + sequence_input_layer = sfc.SequenceFeatures([indicator_column]) + input_layer, _ = sequence_input_layer({'aaa': sparse_input}) shape = input_layer.get_shape() self.assertEqual(shape, expected_shape) + def test_compute_output_shape(self): + price1 = sfc.sequence_numeric_column('price1', shape=2) + price2 = sfc.sequence_numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': sparse_tensor.SparseTensor( + indices=[[0, 0, 0], [0, 0, 1], + [0, 1, 0], [0, 1, 1], + [1, 0, 0], [1, 0, 1], + [2, 0, 0], [2, 0, 1], + [3, 0, 0], [3, 0, 1]], + values=[0., 1., 10., 11., 100., 101., 200., 201., 300., 301.], + dense_shape=(4, 3, 2)), + 'price2': sparse_tensor.SparseTensor( + indices=[[0, 0], + [0, 1], + [1, 0], + [2, 0], + [3, 0]], + values=[10., 11., 20., 30., 40.], + dense_shape=(4, 3))} + sequence_features = sfc.SequenceFeatures([price1, price2]) + seq_input, seq_len = sequence_features(features) + self.assertEqual( + sequence_features.compute_output_shape((None, None)), + (None, None, 3)) + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllClose([[[0., 1., 10.], [10., 11., 11.], [0., 0., 0.]], + [[100., 101., 20.], [0., 0., 0.], [0., 0., 0.]], + [[200., 201., 30.], [0., 0., 0.], [0., 0., 0.]], + [[300., 301., 40.], [0., 0., 0.], [0., 0., 0.]]], + self.evaluate(seq_input)) + self.assertAllClose([2, 1, 1, 1], self.evaluate(seq_len)) + class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): """Tests the utility fn concatenate_context_input.""" @@ -605,8 +624,8 @@ class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): sfc.concatenate_context_input(context_input, seq_input) -class InputLayerTest(test.TestCase): - """Tests input_layer with sequence feature columns.""" +class DenseFeaturesTest(test.TestCase): + """Tests DenseFeatures with sequence feature columns.""" def test_embedding_column(self): """Tests that error is raised for sequence embedding column.""" @@ -620,16 +639,15 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_embedding\. categorical_column must not be ' - r'of type _SequenceCategoricalColumn\.'): - _ = fc_old.input_layer( - features={'aaa': sparse_input}, - feature_columns=[embedding_column_a]) + r'of type SequenceCategoricalColumn\.'): + input_layer = fc.DenseFeatures([embedding_column_a]) + _ = input_layer({'aaa': sparse_input}) def test_indicator_column(self): """Tests that error is raised for sequence indicator column.""" @@ -643,15 +661,14 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, r'In indicator_column: aaa_indicator\. categorical_column must not be ' - r'of type _SequenceCategoricalColumn\.'): - _ = fc_old.input_layer( - features={'aaa': sparse_input}, - feature_columns=[indicator_column_a]) + r'of type SequenceCategoricalColumn\.'): + input_layer = fc.DenseFeatures([indicator_column_a]) + _ = input_layer({'aaa': sparse_input}) def _assert_sparse_tensor_value(test_case, expected, actual): @@ -946,7 +963,7 @@ class SequenceEmbeddingColumnTest( embedding_column, {'aaa': inputs}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index dad50a3a73085526f65bd87c3d8549ceb75b3af4..3f6dbe0cbdeeae5e2107755f80bcfe5f7fc310e4 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -50,6 +50,8 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_estimator:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index e72e50585a3861d4527b66f89e1659d76c85960a..3784631dcbfbeb215b6c695e4b6f1bbd02fa708c 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -130,17 +130,21 @@ _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', + 'is_sequence_or_composite', 'flatten', 'flatten_dict_items', 'pack_sequence_as', 'map_structure', 'map_structure_with_paths', + 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', + 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', 'yield_flat_paths', 'flatten_with_joined_string_paths', + 'flatten_with_tuple_paths', ] remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 57a5bfbf43c915775c6b0ef05baac19581213a09..f65f450eba49163c319af54ec2bd7f6b61e34c1e 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -171,6 +171,7 @@ cuda_py_test( main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ "manual", # TODO(b/117128481): re-enable after fixing OSS build + "nogpu", "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index c541c71f996c7a1b36cf28ae9a1783f8dca0a72c..b6b75ffa248d66cc4cb49339f193d486f05a6a4a 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index f89d7ed0f45f919b17398de5d9449d12c08dd2f2..db0868fb2c43464a811b3d6dfcd96480ba2463ee 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,12 +1,14 @@ -# Files for using TFGAN framework. -package(default_visibility = ["//tensorflow:__subpackages__"]) +# Files for using TF-GAN framework. +load("//tensorflow:tensorflow.bzl", "py_test") + +package(default_visibility = [ + "//tensorflow:__subpackages__", +]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "gan", srcs = [ @@ -104,7 +106,9 @@ py_library( deps = [ ":gan_estimator", ":head", + ":latent_gan_estimator", ":stargan_estimator", + ":tpu_gan_estimator", "//tensorflow/python:util", ], ) @@ -128,6 +132,7 @@ py_library( ":clip_weights", ":conditioning_utils", ":random_tensor_pool", + ":spectral_normalization", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -141,16 +146,15 @@ py_library( "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", + "//tensorflow/python:gradients_impl", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", - "//third_party/py/numpy", ], ) @@ -518,15 +522,19 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", @@ -562,28 +570,114 @@ py_test( deps = [ ":namedtuples", ":stargan_estimator", - ":tuple_losses", "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + +py_library( + name = "tpu_gan_estimator", + srcs = [ + "python/estimator/python/tpu_gan_estimator.py", + "python/estimator/python/tpu_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":namedtuples", + ":train", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/contrib/training:training_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:util", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/ops/losses", + ], +) + +py_test( + name = "tpu_gan_estimator_test", + srcs = ["python/estimator/python/tpu_gan_estimator_test.py"], + shard_count = 11, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":namedtuples", + ":tpu_gan_estimator", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) +py_library( + name = "latent_gan_estimator", + srcs = [ + "python/estimator/python/latent_gan_estimator.py", + "python/estimator/python/latent_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":train", + "//tensorflow/python:clip_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:random_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "latent_gan_estimator_test", + srcs = [ + "python/estimator/python/latent_gan_estimator_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":latent_gan_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/ops/losses", + ], +) + py_library( name = "sliced_wasserstein", srcs = [ @@ -618,3 +712,45 @@ py_test( "//third_party/py/numpy", ], ) + +py_library( + name = "spectral_normalization", + srcs = [ + "python/features/python/spectral_normalization.py", + "python/features/python/spectral_normalization_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:standard_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/keras:engine", + ], +) + +py_test( + name = "spectral_normalization_test", + srcs = ["python/features/python/spectral_normalization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":spectral_normalization", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 9ab86329eaf0e6fd426aef1f552f4e27c2ad65de..4eac4e80cdacd779fdbedef19e4a654196f0caf1 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -1,14 +1,15 @@ -# TensorFlow-GAN (TFGAN) + +# TensorFlow-GAN (TF-GAN) -TFGAN is a lightweight library for training and evaluating Generative +TF-GAN is a lightweight library for training and evaluating Generative Adversarial Networks (GANs). This technique allows you to train a network (called the 'generator') to sample from a distribution, without having to explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](http://https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. #### Usage @@ -17,27 +18,27 @@ import tensorflow as tf tfgan = tf.contrib.gan ``` -## Why TFGAN? +## Why TF-GAN? * Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can -mix TFGAN, native TF, and other custom frameworks +mix TF-GAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training * Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) -* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model -* Improvements in TFGAN infrastructure will automatically benefit your TFGAN project +* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model +* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project * Stay up-to-date with research as we add more algorithms -## What are the TFGAN components? +## What are the TF-GAN components? -TFGAN is composed of several parts which were design to exist independently. +TF-GAN is composed of several parts which were design to exist independently. These include the following main pieces (explained in detail below). * [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase can be completed by custom-code or by using a - TFGAN library call. + TF-GAN library call. * [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): Many common GAN operations and normalization techniques are implemented for @@ -56,14 +57,14 @@ These include the following main pieces (explained in detail below). generative models. * [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make - GAN training easier, or use the more complicated examples to jumpstart your + and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make + GAN training easier, or use the more complicated examples to jump-start your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model -Training in TFGAN typically consists of the following steps: +Training in TF-GAN typically consists of the following steps: 1. Specify the input to your networks. 1. Set up your generator and discriminator using a `GANModel`. @@ -71,12 +72,12 @@ Training in TFGAN typically consists of the following steps: 1. Create your train ops using a `GANTrainOps`. 1. Run your train ops. -At each stage, you can either use TFGAN's convenience functions, or you can +At each stage, you can either use TF-GAN's convenience functions, or you can perform the step manually for fine-grained control. We provide examples below. There are various types of GAN setups. For instance, you can train a generator to sample unconditionally from a learned distribution, or you can condition on -extra information such as a class label. TFGAN is compatible with many setups, +extra information such as a class label. TF-GAN is compatible with many setups, and we demonstrate a few below: ### Examples @@ -254,9 +255,9 @@ with variable_scope.variable_scope(dis_scope, reuse=True): discriminator_real_outputs = discriminator_fn(images) generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) -# Depending on what TFGAN features you use, you don't always need to supply +# Depending on what TF-GAN features you use, you don't always need to supply # every `GANModel` field. At a minimum, you need to include the discriminator -# outputs and variables if you want to use TFGAN to construct losses. +# outputs and variables if you want to use TF-GAN to construct losses. gan_model = tfgan.GANModel( generator_inputs, generated_data, diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index f1946c7f925660eae3aaa650c437e03da1f33d6c..1e6000898f7b8a53ad3f6fa12deebd54bf3a57ff 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN is a lightweight library for training and evaluating GANs. +"""TF-GAN is a lightweight library for training and evaluating GANs. In addition to providing the infrastructure for easily training and evaluating GANS, this library contains modules for a TFGAN-backed Estimator, @@ -24,7 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Collapse TFGAN into a tiered namespace. +# Collapse TF-GAN into a tiered namespace. from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 99d38011ba677f03e198a431634fbb2ce349f912..430266555b723e6ca39dccffc1442dbef5d4a385 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN estimator module. +"""TF-GAN estimator module. GANEstimator provides all the infrastructure support of a TensorFlow Estimator -with the feature support of TFGAN. +with the feature support of TF-GAN. """ from __future__ import absolute_import @@ -26,18 +26,25 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator from tensorflow.contrib.gan.python.estimator.python import stargan_estimator +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ +_allowed_symbols = ([ 'gan_estimator', 'stargan_estimator', + 'tpu_gan_estimator', + 'latent_gan_estimator', 'head', -] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ + + tpu_gan_estimator.__all__ + latent_gan_estimator.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index adb72228217892fffc10b0e2630edcd9d3e38a02..dd904611d1a3bb78de8316d5ed29ab0f800f29a9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division @@ -56,10 +56,10 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have + This Estimator is backed by TF-GAN. The network functions follow the TF-GAN + API except for one exception: if either `generator_fn` or `discriminator_fn` + have an argument called `mode`, then the tf.Estimator mode is passed in for + that argument. This helps with operations like batch normalization, which have different train and evaluation behavior. Example: @@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator): import tensorflow as tf tfgan = tf.contrib.gan - # See TFGAN's `train.py` for a description of the generator and + # See TF-GAN's `train.py` for a description of the generator and # discriminator API. def generator_fn(generator_inputs): ... @@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if + generator. See `TF-GAN` for more details and examples. Additionally, if it has an argument called `mode`, the Estimator's `mode` will be passed in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details and examples. generator_loss_fn: The loss function on the generator. Takes a `GANModel` tuple. diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 5a3d29cf0b3cb1bbe03cb5ba4f327caf46432b76..5b9c54e43a16adf457d5ed0e7e73dcd168ab0d67 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's estimator.py.""" +"""Tests for TF-GAN's estimator.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 1a0ee6dfc498eb6dc8c97411589d9e35bc352062..cbe990b476c3b17ce61e0826b17d10976fea43c7 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8205bc889dc01c8680e2139393d65723280cfbd0..5b50234a0e33cd297b176f142b358338966b6758 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's head.py.""" +"""Tests for TF-GAN's head.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..4e164e24168bb0cc5e9a7cc772081781ea088bb1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `Train Input Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = latent_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..f5afc7731937ed1a82c8ebb5969b2687ffdd583b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements an estimator wrapper that allows training the input latent space. + +This file implements a latent gan estimator that wraps around a previously +trained GAN. The latent gan estimator trains a single variable z, representing +the hidden latent distribution that is the 'noise' input to the GAN. By training +z, the inpainting estimator can move around the latent z space towards +minimizing a specific loss function. + +The latent gan estimator has a few key differences from a normal estimator. + +First: the variables in the estimator should not be saved, as we are not +updating the original GAN and are only adding a new z variable that is meant +to be different for each run. In order to do distributed training using +train_and_evaluate, the Tensorflow RunConfig is expected to save checkpoints +by having either save_checkpoints_steps or save_checkpoints_secs saved. +To avoid this conflict, we purposely set the save_checkpoints_steps value in +the RunConfig to be one step more than the total number of steps that the +inpainter estimator will run. + +Second: we need to specify warm start settings, as we are reloading the +GAN model into a different graph (specifically, one with a new z variable). +The warm start settings defined below reload all GAN variables and ignore the +new z variable (and the optimizer). + +Usage: + + def _generator(net, mode): + ... + + def _discriminator(net, condition, mode): + ... + + def _loss(gan_model, features, labels, add_summaries): + ... + + def optimizer(): + ... + + params = {} + config = tf.estimator.RunConfig() + tmp_dir = path/to/output/storage + + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir) + + def input_fn(): + ... + + estimator.train(input_fn=input_fn) + +See latent_gan_estimator_test.py or tensorflow_models/gan/face_inpainting for +further examples. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +INPUT_NAME = 'new_var_z_input' # The name for the new z space input variable. +OPTIMIZER_NAME = 'latent_gan_optimizer' # The name for the new optimizer vars. + +__all__ = [ + 'get_latent_gan_estimator', +] + + +def _get_latent_gan_model_fn(generator_fn, discriminator_fn, loss_fn, + optimizer): + """Sets up a model function that wraps around a given GAN.""" + def model_fn(features, labels, mode, params): + """Model function defining an inpainting estimator.""" + batch_size = params['batch_size'] + z_shape = [batch_size] + params['z_shape'] + add_summaries = params['add_summaries'] + input_clip = params['input_clip'] + + z = variable_scope.get_variable( + name=INPUT_NAME, initializer=random_ops.truncated_normal(z_shape), + constraint=lambda x: clip_ops.clip_by_value(x, -input_clip, input_clip)) + + generator = functools.partial(generator_fn, mode=mode) + discriminator = functools.partial(discriminator_fn, mode=mode) + gan_model = tfgan_train.gan_model(generator_fn=generator, + discriminator_fn=discriminator, + real_data=labels, + generator_inputs=z, + check_shapes=False) + + loss = loss_fn(gan_model, features, labels, add_summaries) + + # Use a variable scope to make sure that estimator variables dont cause + # save/load problems when restoring from ckpts. + with variable_scope.variable_scope(OPTIMIZER_NAME): + opt = optimizer(learning_rate=params['learning_rate'], + **params['opt_kwargs']) + train_op = opt.minimize( + loss=loss, global_step=training_util.get_or_create_global_step(), + var_list=[z]) + + if add_summaries: + z_grads = gradients_impl.gradients(loss, z) + summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads)) + summary.scalar('z_loss/loss', loss) + + return model_fn_lib.EstimatorSpec(mode=mode, + predictions=gan_model.generated_data, + loss=loss, + train_op=train_op) + return model_fn + + +def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn, + optimizer, params, config, ckpt_dir, + warmstart_options=True): + """Gets an estimator that passes gradients to the input. + + This function takes in a generator and adds a trainable z variable that is + used as input to this generator_fn. The generator itself is treated as a black + box through which gradients can pass through without updating any weights. The + result is a trainable way to traverse the GAN latent space. The loss_fn is + used to actually train the z variable. The generator_fn and discriminator_fn + should be previously trained by the tfgan library (on reload, the variables + are expected to follow the tfgan format. It may be possible to use the + latent gan estimator with entirely custom GANs that do not use the tfgan + library as long as the appropriate variables are wired properly). + + Args: + generator_fn: a function defining a Tensorflow graph for a GAN generator. + The weights defined in this graph should already be defined in the given + checkpoint location. Should have 'mode' as an argument. + discriminator_fn: a function defining a Tensorflow graph for a GAN + discriminator. Should have 'mode' as an argument. + loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in a + GANModel tuple, features, labels, and add_summaries as inputs. + optimizer: a tf.Optimizer or a function that returns a tf.Optimizer with no + inputs. + params: An object containing the following parameters: + - batch_size: an int indicating the size of the training batch. + - z_shape: the desired shape of the input z values (not counting batch). + - learning_rate: a scalar or function defining a learning rate applied to + optimizer. + - input_clip: the amount to clip the x training variable by. + - add_summaries: whether or not to add summaries. + - opt_kwargs: optimizer kwargs. + config: tf.RunConfig. Should point model to output dir and should indicate + whether to save checkpoints (to avoid saving checkpoints, set + save_checkpoints_steps to a number larger than the number of train steps). + The model_dir field in the RunConfig should point to a directory WITHOUT + any saved checkpoints. + ckpt_dir: the directory where the model checkpoints live. The checkpoint is + used to warm start the underlying GAN. This should NOT be the same as + config.model_dir. + warmstart_options: boolean, None, or a WarmStartSettings object. If set to + True, uses a default WarmStartSettings object. If set to False or None, + does not use warm start. If using a custom WarmStartSettings object, make + sure that new variables are properly accounted for when reloading the + underlying GAN. Defaults to True. + Returns: + An estimator spec defining a GAN input training estimator. + """ + model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn, + loss_fn, optimizer) + + if isinstance(warmstart_options, estimator.WarmStartSettings): + ws = warmstart_options + elif warmstart_options: + # Default WarmStart loads all variable names except INPUT_NAME and + # OPTIMIZER_NAME. + var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME) + ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir, + vars_to_warm_start=var_regex) + else: + ws = None + + if 'opt_kwargs' not in params: + params['opt_kwargs'] = {} + + return estimator.Estimator(model_fn=model_fn, config=config, params=params, + warm_start_from=ws) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac139e532e35f7aae6da0655103a7249fe3382d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for latent_gan_estimator. + +See g3.tp.tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator +from tensorflow.python.estimator import run_config as run_config +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class TrainInputEstimatorTest(test.TestCase): + + def test_get_input_training_estimator(self): + """Integration test to make sure the input_training_estimator works.""" + + # Create dummy test input tensors. + true_features = np.reshape(np.random.uniform(size=100), (10, 10)) + true_labels = np.reshape(np.random.uniform(size=100), (5, 20)) + expected_z_output = [[1, -1], [-1, 1]] + + # Fill out required parameters randomly, includes optimizer kwargs. + params = { + 'batch_size': 2, + 'z_shape': [2], + 'learning_rate': 1.0, + 'input_clip': 1.0, + 'add_summaries': False, + 'opt_kwargs': { + 'beta1': 0.1 + } + } + + input_z_shape = [params['batch_size']] + params['z_shape'] + + # Create dummy model functions that represent an underlying GANEstimator and + # the input training wrapper. Make sure that everything is wired up + # correctly in the internals of each dummy function. + def _generator(net, mode): + """The generator function will get the newly created z variable.""" + del mode + self.assertSequenceEqual(net.shape, input_z_shape) + gen_dummy_var = variable_scope.get_variable( + name='generator_dummy_variable', + initializer=array_ops.ones(input_z_shape)) + return net * gen_dummy_var + + def _discriminator(net, condition, mode): + """The discriminator function will get either the z variable or labels.""" + del condition, mode + try: + self.assertSequenceEqual(net.shape, true_labels.shape) + except AssertionError: + self.assertSequenceEqual(net.shape, input_z_shape) + return net + + def _loss(gan_model, features, labels, _): + """Make sure that features and labels are passed in from input.""" + self.assertTrue(np.array_equal(features, true_features)) + self.assertTrue(np.array_equal(labels, true_labels)) + return losses.absolute_difference(expected_z_output, + gan_model.generated_data) + + optimizer = training.AdamOptimizer + + # We are not loading checkpoints, so set the corresponding directory to a + # dummy directories. + tmp_dir = tempfile.mkdtemp() + config = run_config.RunConfig(model_dir=tmp_dir, + save_summary_steps=None, + save_checkpoints_steps=1, + save_checkpoints_secs=None) + + # Get the estimator. Disable warm start so that there is no attempted + # checkpoint reloading. + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir, + warmstart_options=None) + + # Train for a few steps. + def dummy_input(): + return true_features, true_labels + estimator.train(input_fn=dummy_input, steps=10) + + # Make sure the generator variables did not change, but the z variables did + # change. + self.assertTrue(np.array_equal( + estimator.get_variable_value('Generator/generator_dummy_variable'), + np.ones(input_z_shape))) + self.assertTrue(np.array_equal( + estimator.get_variable_value('new_var_z_input'), + expected_z_output)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py index f60e16bc04662b33bc0bb22b5acc8c7fcc7a03ba..2a485e7d47ff10cf34c1b44f8dcc6b1f33c9a05f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed StarGAN Estimator.""" +"""A TF-GAN-backed StarGAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index 2ec7938c7c4051842c7e982b54c1213b6e841b79..c00ff4399748a77f88d9753df7592bf3859d754e 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's stargan_estimator.py.""" +"""Tests for TF-GAN's stargan_estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -80,7 +80,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(input_data, gan_model.input_data) self.assertIsNotNone(gan_model.generated_data) self.assertIsNotNone(gan_model.generated_data_domain_target) - self.assertEqual(1, len(gan_model.generator_variables)) + self.assertLen(gan_model.generator_variables, 1) self.assertIsNotNone(gan_model.generator_scope) self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: @@ -109,7 +109,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): gan_model.discriminator_input_data_domain_predication) self.assertIsNotNone( gan_model.discriminator_generated_data_domain_predication) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -163,6 +163,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..deb381f7be3f9545ed918813ee55aede946f22d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `TPUGANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = tpu_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2a22c78a304c7cc66ef069a235483e9279b3b2 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py @@ -0,0 +1,423 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TF-GAN-backed GAN Estimator that works on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as gan_estimator_lib +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.contrib.training.python.training import training +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops.losses import losses + +__all__ = [ + 'TPUGANEstimator', +] + + +class TPUGANEstimator(tpu_estimator.TPUEstimator): + """An estimator for Generative Adversarial Networks (GANs) on TPU. + + This Estimator is backed by TFGAN. It is similar to `tfgan.GANEstimator`, + but works on TPU. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. + config = tpu_config.RunConfig(model_dir='/my/dir') + gan_estimator = tfgan.estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + train_batch_size=4, + config=config) + + # Train estimator. + gan_estimator.train(train_input_fn, train_steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn, eval_steps) + + # Generate samples from generator. + predictions = np.array([ + x['generated_data'] for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + # Arguments to construct the `model_fn`. + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + get_eval_metric_ops_fn=None, + add_summaries=None, + joint_train=False, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1), + # TPUEstimator options. + model_dir=None, + config=None, + params=None, + use_tpu=True, + train_batch_size=None, + eval_batch_size=None, + predict_batch_size=None, + batch_axis=None, + eval_on_tpu=True, + export_to_tpu=True, + warm_start_from=None): + """Initializes a TPUGANEstimator instance. + + Args: + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + get_eval_metric_ops_fn: A function that takes a list of arguments and + returns a dict of metric results keyed by name. The output of this + function is passed into `tf.estimator.EstimatorSpec` during evaluation. + The arguments must be: + * generator_inputs + * generated_data + * real_data + * discriminator_real_outputs + * discriminator_gen_outputs + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + This is ignored for jobs that run on TPU, such as the train job if + `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`. + joint_train: A Python boolean. If `True`, jointly train the generator and + the discriminator. If `False`, sequentially train them. See `train.py` + in TFGAN for more details on the differences between the two GAN + training methods. + gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio + of generator to discriminator steps. For now, only supports 1:1 + training. + model_dir: Same as `TPUEstimator`: Directory to save model parameters, + graph and etc. This can also be used to load checkpoints from the + directory into a estimator to continue training a previously saved + model. If `None`, the model_dir in `config` will be used if set. If both + are set, they must be same. If both are `None`, a temporary directory + will be used. + config: Same as `TPUEstimator`: An `tpu_config.RunConfig` configuration + object. Cannot be `None`. + params: Same as `TPUEstimator`: An optional `dict` of hyper parameters + that will be passed into `input_fn` and `model_fn`. Keys are names of + parameters, values are basic python types. There are reserved keys for + `TPUEstimator`, including 'batch_size'. + use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is + enabled. Currently, TPU training and evaluation respect this bit, but + eval_on_tpu can override execution of eval. See below. Predict still + happens on CPU. + train_batch_size: Same as `TPUEstimator`: An int representing the global + training batch size. TPUEstimator transforms this global batch size to a + per-shard batch size, as params['batch_size'], when calling `input_fn` + and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be + divisible by total number of replicas. + eval_batch_size: Same as `TPUEstimator`: An int representing evaluation + batch size. Must be divisible by total number of replicas. + predict_batch_size: Same as `TPUEstimator`: An int representing the + prediction batch size. Must be divisible by total number of replicas. + batch_axis: Same as `TPUEstimator`: A python tuple of int values + describing how each tensor produced by the Estimator `input_fn` should + be split across the TPU compute shards. For example, if your input_fn + produced (images, labels) where the images tensor is in `HWCN` format, + your shard dimensions would be [3, 0], where 3 corresponds to the `N` + dimension of your images Tensor, and 0 corresponds to the dimension + along which to split the labels to match up with the corresponding + images. If None is supplied, and per_host_input_for_training is True, + batches will be sharded based on the major dimension. If + tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, + batch_axis is ignored. + eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or + GPU. In this case, the model_fn must return `EstimatorSpec` when called + with `mode` as `EVAL`. + export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()` + exports a metagraph for serving on TPU besides the one on CPU. + warm_start_from: Same as `TPUEstimator`: Optional string filepath to a + checkpoint or SavedModel to warm-start from, or a + `tf.estimator.WarmStartSettings` object to fully configure + warm-starting. If the string filepath is provided instead of a + `WarmStartSettings`, then all variables are warm-started, and it is + assumed that vocabularies and Tensor names are unchanged. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps` + tuple. + ValueError: If `gan_train_steps` isn't 1:1 training. + """ + if not callable(generator_loss_fn): + raise ValueError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise ValueError('discriminator_loss_fn must be callable.') + if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps): + raise ValueError( + '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. Instead, ' + 'was type: %s' % type(gan_train_steps)) + if (gan_train_steps.generator_train_steps != 1 or + gan_train_steps.discriminator_train_steps != 1): + raise ValueError('Estimator currently only supports 1:1 training.') + + if use_tpu: + generator_optimizer = _maybe_make_cross_shard_optimizer( + generator_optimizer) + discriminator_optimizer = _maybe_make_cross_shard_optimizer( + discriminator_optimizer) + + def _model_fn(features, labels, mode, params): + """GANEstimator model function.""" + del params # unused + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then + # remove `add_summaries` logic below. + is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu) + gan_model = gan_estimator_lib._get_gan_model( # pylint:disable=protected-access + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=None if is_on_tpu else add_summaries) + + # Make the TPUEstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). + estimator_spec = _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps) + assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec) + return estimator_spec + + super(TPUGANEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config, + params=params, + use_tpu=use_tpu, + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + predict_batch_size=predict_batch_size, + batch_axis=batch_axis, + eval_on_tpu=eval_on_tpu, + export_to_tpu=export_to_tpu, + warm_start_from=warm_start_from) + + +def _is_on_tpu(mode, use_tpu, eval_on_tpu): + if mode == model_fn_lib.ModeKeys.TRAIN: + return use_tpu + elif mode == model_fn_lib.ModeKeys.EVAL: + return eval_on_tpu + else: + return False + + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps): + """Get the TPUEstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = tpu_estimator.TPUEstimatorSpec( + mode=mode, predictions={'generated_data': gan_model.generated_data}) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + # Eval losses for metrics must preserve batch dimension. + gan_loss_no_reduction = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE)) + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + + # Construct optimizers if arguments were callable. For TPUs, they must be + # `CrossShardOptimizer`. + g_callable = callable(generator_optimizer) + gopt = generator_optimizer() if g_callable else generator_optimizer + d_callable = callable(discriminator_optimizer) + dopt = discriminator_optimizer() if d_callable else discriminator_optimizer + + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, joint_train, gan_train_steps) + + return estimator_spec + + +def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction, + get_eval_metric_ops_fn): + """Return an TPUEstimatorSpec for the eval case.""" + # Make the metric function and tensor names. + if get_eval_metric_ops_fn is not None: + def metric_fn( + generator_inputs, generated_data, real_data, discriminator_real_outputs, + discriminator_gen_outputs, generator_loss, discriminator_loss): + """`metric_fn` used in TPUEstimator to calculate metrics.""" + eval_metric_ops = { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + custom_eval_metric_ops = get_eval_metric_ops_fn( + generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('`get_eval_metric_ops_fn` must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return eval_metric_ops + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + 'generator_inputs': gan_model.generator_inputs, + 'generated_data': gan_model.generated_data, + 'real_data': gan_model.real_data, + 'discriminator_real_outputs': gan_model.discriminator_real_outputs, + 'discriminator_gen_outputs': gan_model.discriminator_gen_outputs, + } + else: + def metric_fn(generator_loss, discriminator_loss): + return { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + } + + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return tpu_estimator.TPUEstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metrics=(metric_fn, tensors)) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + joint_train, gan_train_steps): + """Return a TPUEstimatorSpec for the train case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + + # Get generator and discriminator update ops. We split them so that update + # ops aren't accidentally run multiple times. For now, throw an error if + # there are update ops that aren't associated with either the generator or + # the discriminator. Might modify the `kwargs` dictionary. + gen_update_ops, dis_update_ops = tfgan_train._get_update_ops( # pylint:disable=protected-access + {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name) + + def gen_train_op(): + with ops.name_scope('generator_train'): + return training.create_train_op( + total_loss=gan_loss.generator_loss, + optimizer=generator_optimizer, + variables_to_train=gan_model.generator_variables, + update_ops=gen_update_ops) + def dis_train_op(): + with ops.name_scope('discriminator_train'): + return training.create_train_op( + total_loss=gan_loss.discriminator_loss, + optimizer=discriminator_optimizer, + variables_to_train=gan_model.discriminator_variables, + update_ops=dis_update_ops) + + # Either optimize the generator and discriminator sequentially or jointly. + tpu_train_op = _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps) + + return tpu_estimator.TPUEstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=tpu_train_op) + + +# TODO(joelshor): Add support for multiple D / G steps. +def _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps): + """Combine generator and discriminator train ops into a single op.""" + del gan_train_steps + if joint_train: + tpu_train_op = control_flow_ops.group(gen_train_op(), dis_train_op(), + name='joint_train') + else: + with ops.control_dependencies([dis_train_op()]): + tpu_train_op = gen_train_op() + + return tpu_train_op + + +def _maybe_make_cross_shard_optimizer(opt): + if callable(opt): + if not isinstance(opt(), tpu_optimizer.CrossShardOptimizer): + return lambda: tpu_optimizer.CrossShardOptimizer(opt()) + elif not isinstance(opt, tpu_optimizer.CrossShardOptimizer): + return tpu_optimizer.CrossShardOptimizer(opt) + return opt diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9e6489bdd1d89cc49bfedc2eed784999c31d2b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -0,0 +1,319 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TF-GAN's TPU Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.estimator import WarmStartSettings +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework.errors_impl import NotFoundError +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training +from tensorflow.python.training import training_util + +FLAGS = flags.FLAGS + +flags.DEFINE_bool('use_tpu', False, 'Whether to run test on TPU or not.') + + +def generator_fn(noise, mode): + del mode + return layers.fully_connected(noise, tensor_shape.dimension_value( + noise.shape[1])) + + +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode + return layers.fully_connected(data, 1) + + +def get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def get_metrics(generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs): + del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + real_data, generated_data) + } + + +class GetTPUEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + super(GetTPUEstimatorSpecTest, cls).setUpClass() + cls._generator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + cls._discriminator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + + @parameterized.named_parameters( + ('joint_train', model_fn_lib.ModeKeys.TRAIN, True), + ('train_sequential', model_fn_lib.ModeKeys.TRAIN, False), + ('eval', model_fn_lib.ModeKeys.EVAL, None), + ('predict', model_fn_lib.ModeKeys.PREDICT, None)) + def test_get_estimator_spec(self, mode, joint_train): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer, + joint_train=joint_train, + is_on_tpu=FLAGS.use_tpu, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1)) + + self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec) + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual({'generated_data': self._gan_model.generated_data}, + spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metrics) + + +class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(TPUGANEstimatorIntegrationTest, self).setUp() + self._model_dir = tempfile.mkdtemp() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + + def tearDown(self): + super(TPUGANEstimatorIntegrationTest, self).tearDown() + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False, joint_train=True): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + joint_train=joint_train, + get_eval_metric_ops_fn=get_metrics, + train_batch_size=4, + eval_batch_size=10, + predict_batch_size=8, + use_tpu=FLAGS.use_tpu, + config=self._config) + + # Train. + num_steps_train = 10 + est.train(train_input_fn, steps=num_steps_train) + + # Evaluate. + num_steps_eval = 2 + scores = est.evaluate(eval_input_fn, steps=num_steps_eval) + self.assertIn(ops.GraphKeys.GLOBAL_STEP, six.iterkeys(scores)) + self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) + + # Predict. + predictions = np.array([x['generated_data'] for x in + est.predict(predict_input_fn)]) + self.assertAllEqual(prediction_size, predictions.shape) + + @parameterized.named_parameters( + ('joint_train', True, False, False), + ('train_sequential', False, False, False), + ('lr_decay', False, True, False), + ('train_sequential_ds', False, False, True)) + def test_numpy_input_fn(self, joint_train, lr_decay, return_ds): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + def train_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + def eval_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + predict_size = 10 + def predict_input_fn(params): + del params # unused + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors(data) + .repeat(predict_size) + .batch(1, drop_remainder=True)) + return ds + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[predict_size, input_dim], + lr_decay=lr_decay, + joint_train=joint_train) + + +class TPUGANEstimatorWarmStartTest(test.TestCase): + + def setUp(self): + self._model_dir = self.get_temp_dir() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + self.new_variable_name = 'new_var' + self.new_variable_value = [1.0, 2.0, 3.0] + + def tearDown(self): + writer_cache.FileWriterCache.clear() + + def _test_warm_start(self, warm_start_from=None): + """Tests whether WarmStartSettings work as intended.""" + def generator_with_new_variable(noise_dict, mode): + variable_scope.get_variable(name=self.new_variable_name, + initializer=self.new_variable_value, + trainable=True) + return generator_fn(noise_dict, mode) + + est = estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + config=self._config) + + def train_input_fn(params): + data = np.zeros([params['batch_size'], 4], dtype=np.float32) + return data, data + + est.train(train_input_fn, steps=1) + + est_warm = estimator.TPUGANEstimator( + generator_fn=generator_with_new_variable, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + config=tpu_config.RunConfig( + model_dir=None if warm_start_from else self._model_dir), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + warm_start_from=warm_start_from) + + est_warm.train(train_input_fn, steps=1) + + return est_warm + + def test_warm_start_error(self): + """Test if exception when reloading different estimators.""" + with self.assertRaises(NotFoundError): + self._test_warm_start() + + def test_warm_start_success(self): + """Test if GANEstimator allows explicit warm start variable assignment.""" + # Regex matches all variable names in ckpt except for new_var. + var_regex = '^(?!.*%s.*)' % self.new_variable_name + warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, + vars_to_warm_start=var_regex) + est_warm = self._test_warm_start(warm_start_from=warmstart) + full_variable_name = 'Generator/%s' % self.new_variable_name + self.assertIn(full_variable_name, est_warm.get_variable_names()) + equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), + self.new_variable_value) + self.assertTrue(equal_vals) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index f86b8513053a45f9830411f7df2c32d1f36a97b2..92e9abf8a35de1999eb800e169f32220fe47f8cd 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN evaluation module. +"""TF-GAN evaluation module. This module supports techniques such as Inception Score, Frechet Inception distance, and Sliced Wasserstein distance. diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py index 1c872626a957279132772ae27df7a66a2564e9a5..a52e899114b62cb29752f72aa59f142f4a428aa1 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index a71ee53311c1c057a5b41be0331bf56ce1a82f74..31f0d34ed68a6adc25cca102236079d0f66615cb 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN. +"""Model evaluation tools for TF-GAN. These methods come from https://arxiv.org/abs/1606.03498, https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. @@ -387,7 +387,7 @@ def classifier_score_from_logits(logits): # Use maximum precision for best results. logits_dtype = logits.dtype if logits_dtype != dtypes.float64: - logits = math_ops.to_double(logits) + logits = math_ops.cast(logits, dtypes.float64) p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) @@ -562,8 +562,8 @@ def mean_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute means of activations. m = math_ops.reduce_mean(real_activations, 0) @@ -623,8 +623,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. m, var = nn_impl.moments(real_activations, axes=[0]) @@ -698,15 +698,16 @@ def frechet_classifier_distance_from_activations(real_activations, activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) - num_examples_generated = math_ops.to_double( - array_ops.shape(generated_activations)[0]) + num_examples_real = math_ops.cast( + array_ops.shape(real_activations)[0], dtypes.float64) + num_examples_generated = math_ops.cast( + array_ops.shape(generated_activations)[0], dtypes.float64) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m @@ -794,9 +795,9 @@ def kernel_classifier_distance(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. dtype: if not None, coerce activations to this dtype before computations. @@ -871,9 +872,9 @@ def kernel_classifier_distance_and_std(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. dtype: if not None, coerce activations to this dtype before computations. @@ -910,7 +911,7 @@ def kernel_classifier_distance_and_std(real_images, gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) return kernel_classifier_distance_and_std_from_activations( - real_a, gen_a, max_block_size=max_block_size) + real_a, gen_a, max_block_size, dtype) kernel_inception_distance_and_std = functools.partial( @@ -967,14 +968,14 @@ def kernel_classifier_distance_from_activations(real_activations, into blocks for computational efficiency. Larger values are more computationally expensive but decrease the variance of the distance estimate. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type as the output of the activations. """ return kernel_classifier_distance_and_std_from_activations( - real_activations, generated_activations, max_block_size=max_block_size)[0] + real_activations, generated_activations, max_block_size, dtype)[0] def kernel_classifier_distance_and_std_from_activations(real_activations, @@ -1029,7 +1030,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type @@ -1080,7 +1081,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, dim = math_ops.cast(real_activations.shape[1], dtype) def compute_kid_block(i): - 'Compute the ith block of the KID estimate.' + """Computes the ith block of the KID estimate.""" r_s = inds_r[i] r_e = inds_r[i + 1] r = real_activations[r_s:r_e] diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index dbff1d2a367e10adc607dafb4c571bb3607a3963..bd17571a0535a3c8e9dfee24a8da16eb2e72f165 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN classifier_metrics.""" +"""Tests for TF-GAN classifier_metrics.""" from __future__ import absolute_import from __future__ import division @@ -234,7 +234,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): else: logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) - self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) # Check that none of the model variables are trainable. @@ -258,7 +258,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): img, _get_dummy_graphdef(), output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(pool, ops.Tensor) pool.shape.assert_is_compatible_with([batch_size, 2048]) # Check that none of the model variables are trainable. @@ -276,8 +276,8 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.INCEPTION_FINAL_POOL ]) - self.assertTrue(isinstance(logits, ops.Tensor)) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) + self.assertIsInstance(pool, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -290,7 +290,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.inception_score, array_ops.zeros([6, 299, 299, 3]), num_batches=3) - self.assertTrue(isinstance(score, ops.Tensor)) + self.assertIsInstance(score, ops.Tensor) score.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -302,7 +302,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock( classifier_metrics.frechet_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -314,7 +314,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py index 523968bed91f1021ae629bf52c405cf5c2d7b917..326fcb3cdbf2eda66207f134cd2926f09a216a99 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries.py b/tensorflow/contrib/gan/python/eval/python/summaries.py index ecfdb39499b1e824e02415c0db1de3157e4f3216..1b202dfc97304ddc7ced42d65366aaf419439392 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index f9995bb19d0d09eaf6fd96d039b0bba1d3a7055c..9f448d3a1602c503093214201bdc96fc9bee85b5 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 54a6f8d4d9086ad7fc8db31032677628561e48e8..53fc7cb8ede698c2d8590c7fd3016a884cef9be9 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN summaries.""" +"""Tests for TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 4816daf760143af9f1502873b123ffad8e5ec8ce..410c3a02052cd3a07a36a0ba332a80b3c2705d89 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -27,11 +27,13 @@ from __future__ import print_function from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils from tensorflow.contrib.gan.python.features.python import random_tensor_pool +from tensorflow.contrib.gan.python.features.python import spectral_normalization from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * +from tensorflow.contrib.gan.python.features.python.spectral_normalization import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ _allowed_symbols += random_tensor_pool.__all__ +_allowed_symbols += spectral_normalization.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..54d3d0a218dec3588844333cd47e1f92489d8df9 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = spectral_normalization_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc653f0a7907f407e66add5537d1e0a5adb6d8b --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -0,0 +1,315 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import numbers +import re + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + +__all__ = [ + 'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer', + 'spectral_normalization_custom_getter', 'keras_spectral_normalization' +] + +# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then +# can't directly be assigned back to the tf.bfloat16 variable. +_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64) +_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u' + + +def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): + """Estimates the largest singular value in the weight tensor. + + Args: + w_tensor: The weight matrix whose spectral norm should be computed. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + The largest singular value (the spectral norm) of w. + """ + with variable_scope.variable_scope(name, 'spectral_norm'): + # The paper says to flatten convnet kernel weights from + # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D + # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to + # (KH * KW * C_in, C_out), and similarly for other layers that put output + # channels as last dimension. + # n.b. this means that w here is equivalent to w.T in the paper. + w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1])) + + # Persisted approximation of first left singular vector of matrix `w`. + u_var = variable_scope.get_variable( + _PERSISTED_U_VARIABLE_SUFFIX, + shape=(w.shape[0], 1), + dtype=w.dtype, + initializer=init_ops.random_normal_initializer(), + trainable=False) + u = u_var + + # Use power iteration method to approximate spectral norm. + for _ in range(power_iteration_rounds): + # `v` approximates the first right singular vector of matrix `w`. + v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u)) + u = nn.l2_normalize(math_ops.matmul(w, v)) + + # Update persisted approximation. + with ops.control_dependencies([u_var.assign(u, name='update_u')]): + u = array_ops.identity(u) + + u = array_ops.stop_gradient(u) + v = array_ops.stop_gradient(v) + + # Largest singular value of `w`. + spectral_norm = math_ops.matmul( + math_ops.matmul(array_ops.transpose(u), w), v) + spectral_norm.shape.assert_is_fully_defined() + spectral_norm.shape.assert_is_compatible_with([1, 1]) + + return spectral_norm[0][0] + + +def spectral_normalize(w, power_iteration_rounds=1, name=None): + """Normalizes a weight matrix by its spectral norm. + + Args: + w: The weight matrix to be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + name: An optional scope name. + + Returns: + A normalized weight matrix tensor. + """ + with variable_scope.variable_scope(name, 'spectral_normalize'): + w_normalized = w / compute_spectral_norm( + w, power_iteration_rounds=power_iteration_rounds) + return array_ops.reshape(w_normalized, w.get_shape()) + + +def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): + """Returns a functions that can be used to apply spectral norm regularization. + + Small spectral norms enforce a small Lipschitz constant, which is necessary + for Wasserstein GANs. + + Args: + scale: A scalar multiplier. 0.0 disables the regularizer. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yeilds a better approximation. + scope: An optional scope name. + + Returns: + A function with the signature `sn(weights)` that applies spectral norm + regularization. + + Raises: + ValueError: If scale is negative or if scale is not a float. + """ + if isinstance(scale, numbers.Integral): + raise ValueError('scale cannot be an integer: %s' % scale) + if isinstance(scale, numbers.Real): + if scale < 0.0: + raise ValueError( + 'Setting a scale less than 0 on a regularizer: %g' % scale) + if scale == 0.0: + logging.info('Scale of 0 disables regularizer.') + return lambda _: None + + def sn(weights, name=None): + """Applies spectral norm regularization to weights.""" + with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name: + scale_t = ops.convert_to_tensor( + scale, dtype=weights.dtype.base_dtype, name='scale') + return math_ops.multiply( + scale_t, + compute_spectral_norm( + weights, power_iteration_rounds=power_iteration_rounds), + name=name) + + return sn + + +def _default_name_filter(name): + """A filter function to identify common names of weight variables. + + Args: + name: The variable name. + + Returns: + Whether `name` is a standard name for a weight/kernel variables used in the + Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries. + """ + match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name) + return match is not None + + +def spectral_normalization_custom_getter(name_filter=_default_name_filter, + power_iteration_rounds=1): + """Custom getter that performs Spectral Normalization on a weight tensor. + + Specifically it divides the weight tensor by its largest singular value. This + is intended to stabilize GAN training, by making the discriminator satisfy a + local 1-Lipschitz constraint. + + Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]. + + [sn-gan]: https://openreview.net/forum?id=B1QRgziT- + + To reproduce an SN-GAN, apply this custom_getter to every weight tensor of + your discriminator. The last dimension of the weight tensor must be the number + of output channels. + + Apply this to layers by supplying this as the `custom_getter` of a + `tf.variable_scope`. For example: + + with tf.variable_scope('discriminator', + custom_getter=spectral_norm_getter()): + net = discriminator_fn(net) + + IMPORTANT: Keras does not respect the custom_getter supplied by the + VariableScope, so Keras users should use `keras_spectral_normalization` + instead of (or in addition to) this approach. + + It is important to carefully select to which weights you want to apply + Spectral Normalization. In general you want to normalize the kernels of + convolution and dense layers, but you do not want to normalize biases. You + also want to avoid normalizing batch normalization (and similar) variables, + but in general such layers play poorly with Spectral Normalization, since the + gamma can cancel out the normalization in other layers. By default we supply a + filter that matches the kernel variable names of the dense and convolution + layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim + libraries. If you are using anything else you'll need a custom `name_filter`. + + This custom getter internally creates a variable used to compute the spectral + norm by power iteration. It will update every time the variable is accessed, + which means the normalized discriminator weights may change slightly whilst + training the generator. Whilst unusual, this matches how the paper's authors + implement it, and in general additional rounds of power iteration can't hurt. + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Returns: + A custom getter function that applies Spectral Normalization to all + Variables whose names match `name_filter`. + + Raises: + ValueError: If name_filter is not callable. + """ + if not callable(name_filter): + raise ValueError('name_filter must be callable') + + def _internal_getter(getter, name, *args, **kwargs): + """A custom getter function that applies Spectral Normalization. + + Args: + getter: The true getter to call. + name: Name of new/existing variable, in the same format as + tf.get_variable. + *args: Other positional arguments, in the same format as tf.get_variable. + **kwargs: Keyword arguments, in the same format as tf.get_variable. + + Returns: + The return value of `getter(name, *args, **kwargs)`, spectrally + normalized. + + Raises: + ValueError: If used incorrectly, or if `dtype` is not supported. + """ + if not name_filter(name): + return getter(name, *args, **kwargs) + + if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX): + raise ValueError( + 'Cannot apply Spectral Normalization to internal variables created ' + 'for Spectral Normalization. Tried to normalized variable [%s]' % + name) + + if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM: + raise ValueError('Disallowed data type {}'.format(kwargs['dtype'])) + + # This layer's weight Variable/PartitionedVariable. + w_tensor = getter(name, *args, **kwargs) + + if len(w_tensor.get_shape()) < 2: + raise ValueError( + 'Spectral norm can only be applied to multi-dimensional tensors') + + return spectral_normalize( + w_tensor, + power_iteration_rounds=power_iteration_rounds, + name=(name + '/spectral_normalize')) + + return _internal_getter + + +@contextlib.contextmanager +def keras_spectral_normalization(name_filter=_default_name_filter, + power_iteration_rounds=1): + """A context manager that enables Spectral Normalization for Keras. + + Keras doesn't respect the `custom_getter` in the VariableScope, so this is a + bit of a hack to make things work. + + Usage: + with keras_spectral_normalization(): + net = discriminator_fn(net) + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yeilds a better approximation of the + true spectral norm. + + Yields: + A context manager that wraps the standard Keras variable creation method + with the `spectral_normalization_custom_getter`. + """ + original_make_variable = keras_base_layer_utils.make_variable + sn_getter = spectral_normalization_custom_getter( + name_filter=name_filter, power_iteration_rounds=power_iteration_rounds) + + def make_variable_wrapper(name, *args, **kwargs): + return sn_getter(original_make_variable, name, *args, **kwargs) + + keras_base_layer_utils.make_variable = make_variable_wrapper + + yield + + keras_base_layer_utils.make_variable = original_make_variable diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea21f70ec01950cfef5e4fa851c78b219d6062f --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py @@ -0,0 +1,354 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for features.spectral_normalization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import slim +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization +from tensorflow.contrib.layers.python.layers import layers as contrib_layers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.layers import convolutional as keras_convolutional +from tensorflow.python.keras.layers import core as keras_core +from tensorflow.python.layers import convolutional as layers_convolutional +from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SpectralNormalizationTest(test.TestCase): + + def testComputeSpectralNorm(self): + weights = variable_scope.get_variable( + 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) + weights = math_ops.multiply(weights, 10.0) + s = linalg_ops.svd( + array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False) + true_sn = s[..., 0] + estimated_sn = spectral_normalization.compute_spectral_norm(weights) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + np_true_sn = sess.run(true_sn) + for i in range(50): + est = sess.run(estimated_sn) + if i < 1: + np_est_1 = est + if i < 4: + np_est_5 = est + if i < 9: + np_est_10 = est + np_est_50 = est + + # Check that the estimate improves with more iterations. + self.assertAlmostEqual(np_true_sn, np_est_50, 0) + self.assertGreater( + abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50)) + self.assertGreater( + abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10)) + self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5)) + + def testSpectralNormalize(self): + weights = variable_scope.get_variable( + 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) + weights = math_ops.multiply(weights, 10.0) + normalized_weights = spectral_normalization.spectral_normalize( + weights, power_iteration_rounds=1) + + unnormalized_sigma = linalg_ops.svd( + array_ops.reshape(weights, [-1, weights.shape[-1]]), + compute_uv=False)[..., 0] + normalized_sigma = linalg_ops.svd( + array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]), + compute_uv=False)[..., 0] + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + s0 = sess.run(unnormalized_sigma) + + for i in range(50): + sigma = sess.run(normalized_sigma) + if i < 1: + s1 = sigma + if i < 5: + s5 = sigma + if i < 10: + s10 = sigma + s50 = sigma + + self.assertAlmostEqual(1., s50, 0) + self.assertGreater(abs(s10 - 1.), abs(s50 - 1.)) + self.assertGreater(abs(s5 - 1.), abs(s10 - 1.)) + self.assertGreater(abs(s1 - 1.), abs(s5 - 1.)) + self.assertGreater(abs(s0 - 1.), abs(s1 - 1.)) + + def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False): + x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3]) + + w_initial = np.random.randn(*w_shape) * 10 + w_initializer = init_ops.constant_initializer(w_initial) + b_initial = np.random.randn(*b_shape) + b_initializer = init_ops.constant_initializer(b_initial) + + if is_keras: + context_manager = spectral_normalization.keras_spectral_normalization() + else: + getter = spectral_normalization.spectral_normalization_custom_getter() + context_manager = variable_scope.variable_scope('', custom_getter=getter) + + with context_manager: + (net, + expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn( + x, w_initializer, b_initializer) + + x_data = np.random.rand(*x.shape) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + + # Before running a forward pass we still expect the variables values to + # differ from the initial value because of the normalizer. + w_befores = [] + for name, var in expected_normalized_vars.items(): + w_before = sess.run(var) + w_befores.append(w_before) + self.assertFalse( + np.allclose(w_initial, w_before), + msg=('%s appears not to be normalized. Before: %s After: %s' % + (name, w_initial, w_before))) + + # Not true for the unnormalized variables. + for name, var in expected_not_normalized_vars.items(): + b_before = sess.run(var) + self.assertTrue( + np.allclose(b_initial, b_before), + msg=('%s appears to be unexpectedly normalized. ' + 'Before: %s After: %s' % (name, b_initial, b_before))) + + # Run a bunch of forward passes. + for _ in range(1000): + _ = sess.run(net, feed_dict={x: x_data}) + + # We expect this to have improved the estimate of the spectral norm, + # which should have changed the variable values and brought them close + # to the true Spectral Normalized values. + _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3])) + exactly_normalized = w_initial / s[0] + for w_before, (name, var) in zip(w_befores, + expected_normalized_vars.items()): + w_after = sess.run(var) + self.assertFalse( + np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8), + msg=('%s did not improve over many iterations. ' + 'Before: %s After: %s' % (name, w_before, w_after))) + self.assertAllClose( + exactly_normalized, + w_after, + rtol=1e-4, + atol=1e-4, + msg=('Estimate of spectral norm for %s was innacurate. ' + 'Normalized matrices do not match.' + 'Estimate: %s Actual: %s' % (name, w_after, + exactly_normalized))) + + def testConv2D_Layers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + layer = layers_convolutional.Conv2D( + filters=3, + kernel_size=3, + padding='same', + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel} + expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_ContribLayers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'], + 'biases': ['CONTRIB_LAYERS_CONV2D_BIASES'] + } + net = contrib_layers.conv2d( + x, + 3, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'contrib.layers.conv2d.weights': weight_vars[0] + } + expected_not_normalized_vars = { + 'contrib.layers.conv2d.bias': bias_vars[0] + } + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_Slim(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['SLIM_CONV2D_WEIGHTS'], + 'biases': ['SLIM_CONV2D_BIASES'] + } + net = slim.conv2d( + x, + 3, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('SLIM_CONV2D_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]} + expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) + + def testConv2D_Keras(self): + + def build_layer_fn(x, w_initializer, b_initializer): + layer = keras_convolutional.Conv2D( + filters=3, + kernel_size=3, + padding='same', + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel} + expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True) + + def testFC_Layers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + x = layers_core.Flatten()(x) + layer = layers_core.Dense( + units=3, + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel} + expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_ContribLayers(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'], + 'biases': ['CONTRIB_LAYERS_FC_BIASES'] + } + x = contrib_layers.flatten(x) + net = contrib_layers.fully_connected( + x, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'contrib.layers.fully_connected.weights': weight_vars[0] + } + expected_not_normalized_vars = { + 'contrib.layers.fully_connected.bias': bias_vars[0] + } + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_Slim(self): + + def build_layer_fn(x, w_initializer, b_initializer): + var_collection = { + 'weights': ['SLIM_FC_WEIGHTS'], + 'biases': ['SLIM_FC_BIASES'] + } + x = slim.flatten(x) + net = slim.fully_connected( + x, + 3, + weights_initializer=w_initializer, + biases_initializer=b_initializer, + variables_collections=var_collection) + weight_vars = ops.get_collection('SLIM_FC_WEIGHTS') + self.assertEquals(1, len(weight_vars)) + bias_vars = ops.get_collection('SLIM_FC_BIASES') + self.assertEquals(1, len(bias_vars)) + expected_normalized_vars = { + 'slim.fully_connected.weights': weight_vars[0] + } + expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,)) + + def testFC_Keras(self): + + def build_layer_fn(x, w_initializer, b_initializer): + x = keras_core.Flatten()(x) + layer = keras_core.Dense( + units=3, + kernel_initializer=w_initializer, + bias_initializer=b_initializer) + net = layer.apply(x) + expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel} + expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias} + + return net, expected_normalized_vars, expected_not_normalized_vars + + self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index a0a86c6337eefa756a209635faa70db686a36247..1f1ae2df4d6def618e86aced3296ac89c836eab7 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -28,7 +28,7 @@ wasserstein_gradient_penalty All losses must be able to accept 1D or 2D Tensors, so as to be compatible with patchGAN style losses (https://arxiv.org/abs/1611.07004). -To make these losses usable in the TFGAN framework, please create a tuple +To make these losses usable in the TF-GAN framework, please create a tuple version of the losses with `losses_utils.py`. """ @@ -38,6 +38,7 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -69,6 +70,10 @@ __all__ = [ ] +def _to_float(tensor): + return math_ops.cast(tensor, dtypes.float32) + + # Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). def wasserstein_generator_loss( discriminator_gen_outputs, @@ -98,7 +103,7 @@ def wasserstein_generator_loss( """ with ops.name_scope(scope, 'generator_wasserstein_loss', ( discriminator_gen_outputs, weights)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = - discriminator_gen_outputs loss = losses.compute_weighted_loss( @@ -144,8 +149,8 @@ def wasserstein_discriminator_loss( with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( discriminator_real_outputs, discriminator_gen_outputs, real_weights, generated_weights)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) @@ -320,7 +325,7 @@ def wasserstein_gradient_penalty( generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. @@ -647,7 +652,7 @@ def least_squares_generator_loss( """ with ops.name_scope(scope, 'lsq_generator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = math_ops.squared_difference( discriminator_gen_outputs, real_label) / 2.0 loss = losses.compute_weighted_loss( @@ -702,8 +707,8 @@ def least_squares_discriminator_loss( """ with ops.name_scope(scope, 'lsq_discriminator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 221c70c38bd432a6be7f6cda9c6700aa2255821f..76e57df7f646547037b3461ac44f7ee5b971406c 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples. +"""TF-GAN utilities for loss functions that accept GANModel namedtuples. The losses and penalties in this file all correspond to losses in `losses_impl.py`. Losses in that file take individual arguments, whereas in this diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 969b68449d9c82f9f9144a8657cd8932b38fd0f7..73dfee4fdeec87cf0bac5eb675fd02a64a9ad7f5 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN. +"""Named tuples for TF-GAN. -TFGAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TFGAN +TF-GAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TF-GAN helper function in `train.py`, or you can manually construct a tuple. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 4c7bee41b33ce1fee46d374ca5fd1c0b603762f9..f36a5d346e0f27fbbc480e876380db51ed559c09 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFGAN project provides a lightweight GAN training/testing framework. +"""The TF-GAN project provides a lightweight GAN training/testing framework. This file contains the core helper functions to create and train a GAN model. See the README or examples in `tensorflow_models` for details on how to use. -TFGAN training occurs in four steps: +TF-GAN training occurs in four steps: 1) Create a model 2) Add a loss 3) Create train ops @@ -645,9 +645,10 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = ( - _tensor_pool_adjusted_model(model, tensor_pool_fn) - if tensor_pool_fn else model) + if tensor_pool_fn: + pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) + else: + pooled_model = model # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,10 +666,11 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = ( - gen_info_loss - if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + if tensor_pool_fn is None: + dis_info_loss = gen_info_loss + else: + dis_info_loss = tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): @@ -929,7 +931,7 @@ def gan_train_ops( **kwargs): """Returns GAN train ops. - The highest-level call in TFGAN. It is composed of functions that can also + The highest-level call in TF-GAN. It is composed of functions that can also be called, should a user require more control over some part of the GAN training process. diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 704be917b3680a1b5712f4f1dc5059b354db8610..bf8b66dcfa5e44a03107cdf1ef8b04e1dbff4a9c 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -17,11 +17,6 @@ filegroup( ]), ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", @@ -66,7 +61,6 @@ cc_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -100,15 +94,37 @@ cc_library( ], ) +cc_library( + name = "gdr_collective_executor_mgr", + srcs = ["gdr_collective_executor_mgr.cc"], + hdrs = ["gdr_collective_executor_mgr.h"], + deps = [ + ":gdr_memory_manager", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:cancellable_call", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", + "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", + "//tensorflow/core/distributed_runtime:worker_cache", + ], +) + cc_library( name = "gdr_server_lib", srcs = ["gdr_server_lib.cc"], hdrs = ["gdr_server_lib.h"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ + ":gdr_collective_executor_mgr", ":gdr_memory_manager", ":gdr_rendezvous_mgr", ":gdr_worker", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 8242d93f129904828a11b61d48f2df8fb0f88bc3..711adc865f37fc84550e4b45d9f0c7fff421a0dc 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -114,7 +114,16 @@ Caveats In current implementation, only tensors that reside in host memory or in GPU memory such that the GPU is adjacent to an RDMA capable NIC will use direct RDMA as its transport. When RDMA is available but not GDR, a temporary tensor copy on host memory will be used as RDMA source/destination (and copied from/to the target device). When there is no RDMA device present, it can even fallback to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly setup so the GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary). -In the original design (as in the reference), tensor buffers are only registered to NIC when we could determine that the tensor will be either a source of Send or a sink of Recv across physical machine boundary. However, to implement the precise allocations, we need to change all the devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid the unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocate all host tensors in GPU-registered buffers no matter they will be transferred from/to GPUs or not. +In the original design (as in the reference), tensor buffers are only registered +to NIC when we could determine that the tensor will be either a source of Send +or a sink of Recv across physical machine boundary. However, to implement the +precise allocations, we need to change all the devices to possibly return a NIC +compatible allocator. As GDR is currently in contrib, we would like to avoid the +unnecessary code disruption to the TF core, so we allocate all tensors from +NIC-registered buffers using a BFC allocator. This behavior is similar to the +effect of enabling the extra GPU option `force_gpu_compatible`, which allocate +all host tensors in GPU-registered buffers no matter they will be transferred +from/to GPUs or not. Reference === diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc new file mode 100644 index 0000000000000000000000000000000000000000..b84710d26eb8a64bf2f86b9f920551a8a8dbb233 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/distributed_runtime/cancellable_call.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/request_id.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +class WorkerCacheInterface; + +namespace { + +class RecvBufCall : public CancellableCall { + public: + RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task, + const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const DeviceLocality& server_locality, + CancellationManager* cancel_mgr, WorkerCacheInterface* wc) + : CancellableCall(cancel_mgr, peer_task, wc) { + req_.set_step_id(step_id); + req_.set_buf_rendezvous_key(key); + *req_.mutable_client_locality() = client_locality; + *req_.mutable_server_locality() = server_locality; + req_.set_num_bytes(to_tensor->TotalBytes()); + req_.set_buf_ptr(reinterpret_cast(DMAHelper::base(to_tensor))); + req_.set_src_device(peer_device); + req_.set_dst_device(to_device->name()); + req_.set_request_id(GetUniqueRequestId()); + } + + ~RecvBufCall() override {} + + void IssueCall(const StatusCallback& done) override { + wi_->RecvBufAsync(&opts_, &req_, &resp_, done); + } + + RecvBufRequest req_; + RecvBufResponse resp_; +}; + +class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { + public: + CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + WorkerCacheInterface* worker_cache, + int64 step_id, + RemoteMemoryManager* remote_memory_manager) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + worker_cache_(worker_cache), + remote_memory_manager_(remote_memory_manager) {} + + ~CollectiveRemoteAccessDistributed() override {} + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + const StatusCallback& done) override { + if (peer_is_local) { + CollectiveRemoteAccessLocal::RecvFromPeer( + peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, + done); + return; + } + + // State that needs to be threaded through a couple of async calls + // in order to make this function completely non-blocking. + struct State { + DeviceLocality server_locality; + std::unique_ptr call; + }; + State* state = new State; + + // Logic to be executed on the RecvBufAsync callback. + auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, + to_device_ctx, to_tensor, dev_to_dev_stream_index, + done](const Status& s) { + if (s.ok()) { + remote_memory_manager_->TensorFromTransportOptions( + to_tensor, state->call->resp_.transport_options(), to_device, + to_device_ctx, to_alloc_attr.on_host(), done); + } + if (!s.ok() && errors::IsFailedPrecondition(s)) { + dev_resolver_->ClearTask(peer_task); + } + + delete state; + }; + + // Logic to execute once we have the device locality for the server-side + // device. + auto dev_locality_callback = [this, state, peer_device, peer_task, key, + to_device, to_device_ctx, to_alloc_attr, + to_tensor, client_locality, + recv_buf_callback](const Status& s) { + if (!s.ok()) { + recv_buf_callback(s); + } else { + state->call.reset(new RecvBufCall( + step_id_, peer_device, peer_task, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, state->server_locality, + &cancel_mgr_, worker_cache_)); + state->call->Start(recv_buf_callback); + } + }; + + dev_resolver_->GetLocalityAsync( + peer_device, peer_task, &state->server_locality, dev_locality_callback); + } + + void StartAbort(const Status& s) override { + CollectiveRemoteAccessLocal::StartAbort(s); + cancel_mgr_.StartCancel(); + } + + protected: + WorkerCacheInterface* worker_cache_; // Not owned + CancellationManager cancel_mgr_; + RemoteMemoryManager* remote_memory_manager_; +}; + +} // namespace + +CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessDistributed* rma = + new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), + worker_cache_, step_id, + remote_memory_manager_); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, + &gpu_ring_order_); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..1417e51e82c31035f058e8e9b546e04fb0ad97b8 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class ConfigProto; +class DeviceMgr; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA. +class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr { + public: + GdrCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name, + RemoteMemoryManager* remote_memory_manager) + : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, + task_name), + remote_memory_manager_(remote_memory_manager) {} + + ~GdrCollectiveExecutorMgr() override {} + + protected: + virtual CollectiveExecutor* Create(int64 step_id) override; + + private: + RemoteMemoryManager* remote_memory_manager_; // Not owned. +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index ce1875151597f926aeb6392e7fc8307312da123f..7321e973191c4cc45f88735c6be7f2f67fe71c39 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -73,7 +73,10 @@ int TryToReadNumaNode(ibv_device* device) { std::ifstream ifs(filename.c_str()); string content; - CHECK(std::getline(ifs, content)); + const auto& ret = std::getline(ifs, content); + if (!ret) { + return port::kNUMANoAffinity; + } int32 value; if (strings::safe_strto32(content, &value)) { diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index dc0d5d548b80d36409778ef34e63171441f10142..c39cc0f9bcecc26aedfaf9707113210acf670244 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_server_lib.h" #include "grpc/support/alloc.h" +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" #include "tensorflow/contrib/gdr/gdr_memory_manager.h" #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h" #include "tensorflow/contrib/gdr/gdr_worker.h" - -#include "grpc/support/alloc.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" namespace tensorflow { @@ -57,10 +59,34 @@ Status GdrServer::Init() { return std::unique_ptr( new GdrWorker(env, config, remote_memory_manager_.get())); }; - + CollectiveMgrCreationFunction collective_mgr_func = + [this](const ConfigProto& config, const WorkerEnv* env, + WorkerCacheInterface* worker_cache) { + string unused; + string default_worker_name; + DeviceNameUtils::SplitDeviceName( + env->device_mgr->ListDevices()[0]->name(), &default_worker_name, + &unused); + + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(env->device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed( + config, env->device_mgr, dev_resolver.get(), worker_cache, + default_worker_name)); + return new GdrCollectiveExecutorMgr( + config, env->device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name, + remote_memory_manager_.get()); + }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); - return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func); + GrpcServerOptions opts; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + opts.collective_mgr_func = collective_mgr_func; + opts.worker_func = worker_func; + return GrpcServer::Init(opts); } Status GdrServer::Start() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 016e5ea27b397830c69b6e1761b5994ebcfa9c3d..1204b8ca501a8f99ea6abd6c047ab2d91350bae1 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_worker.h" +#include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -40,13 +42,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), - recv_tensor_recent_request_ids_(100000) {} + recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - Status s = recv_tensor_recent_request_ids_.TrackUnique( + Status s = recent_request_ids_.TrackUnique( request->request_id(), "RecvTensor (GdrWorker)", *request); if (!s.ok()) { done(s); @@ -145,4 +147,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // This is an RDMA enabled implementation augmenting grpc. + Status s = recent_request_ids_.TrackUnique(request->request_id(), + "RecvBuf (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + CollectiveExecutor::Handle ce_handle( + env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); + CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); + rma->buf_rendezvous()->ConsumeBuf( + request->buf_rendezvous_key(), + [this, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { + Status s = status; + if (s.ok()) { + if (!DMAHelper::CanUseDMA(hook->prod_value)) { + s = errors::Internal("Tensor value for key ", + request->buf_rendezvous_key(), + " is not of a type supported by RecvBuf"); + } + } + if (s.ok()) { + remote_memory_manager_->TransportOptionsFromTensor( + response->mutable_transport_options(), *hook->prod_value, + hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(), + [this, response, done, hook](const Status& s) { + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + }); + } + }); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 39f11e6bde5a1ca7ae91ead02279d22d70af027b..9a85cfd4263ad86f6579eedce95969c2829ff62c 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker { ::grpc::ByteBuffer* response, StatusCallback done) override; + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, + StatusCallback done) override; + private: RemoteMemoryManager* remote_memory_manager_; // Not owned - RecentRequestIds recv_tensor_recent_request_ids_; + RecentRequestIds recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 0081fb61770075a2c36e92f65e01126f657edeb4..d319aa7986d81cf9ac2d1dc2e15b053a0aa0c31b 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -16,9 +16,22 @@ tf_cc_binary( srcs = ["hvx_ops_support_checker_main.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:candidate_sampling_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:manip_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:string_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", + "//tensorflow/core:user_ops_op_lib", "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:file_utils", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 5a8c650fb927be0c835aaceffc516c048195c7bf..c1f6cac4942436d32f9867d4b5557c6b9e376c69 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -30,7 +30,8 @@ system based on Apache Ignite. ## Features -Ignite Dataset provides features that that you can use in a wide range of cases. The most important and interesting features are described below. +Ignite Dataset provides features that you can use in a wide range of cases. The +most important and interesting features are described below. ### Distributed In-Memory Datasource [Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid limitations of hard drive and store and operate with as much data as you need in distributed cluster. You can utilize @@ -97,6 +98,7 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) >>> @@ -116,7 +118,15 @@ Using this ability we can calculate gradients on the nodes the data is stored on Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition. -Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset. +Ignite Dataset allows using these two aspects of distributed neural network +training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a +computation graph operation that can be performed on a remote worker. The remote +worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) +by setting correspondent environment variables for worker process (such as +`IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using +this overriding approach, we can assign a specific partition to every worker so +that one worker handles one partition and, at the same time, transparently work +with single dataset. ```python >>> import tensorflow as tf @@ -149,23 +159,31 @@ system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in addition to its own APIs, IGFS implements Hadoop FileSystem API and can be transparently plugged into Hadoop or Spark deployments. This contrib package -contains an integration between IGFS and TensorFlow. The integration is based -on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) -from TensorFlow side and +contains an integration between IGFS and TensorFlow. The integration is based on +[custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) from +TensorFlow side and [IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache -Ignite side. It has numerous uses, for example: * Checkpoints of state can be -saved to IGFS for reliability and fault-tolerance. * Training processes -communicate with TensorBoard by writing event files to a directory, which -TensorBoard watches. IGFS allows this communication to work even when -TensorBoard runs in a different process or machine. +Ignite side. It has numerous uses, for example: + +* Checkpoints of state can be saved to IGFS for reliability and + fault-tolerance. +* Training processes communicate with TensorBoard by writing event files to a + directory, which TensorBoard watches. IGFS allows this communication to work + even when TensorBoard runs in a different process or machine. ### SSL Connection -Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. +Apache Ignite allows to protect data transfer channels by +[SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and +authentication. Ignite Dataset supports both SSL connection with and without +authentication. For more information, please refer to the +[Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) +documentation. ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", @@ -186,7 +204,7 @@ Following examples will help you to easily start working with this module. The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded -[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interrupt with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: @@ -197,13 +215,13 @@ docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist After that you will be able to work with it following way: -![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist-2.png "Ignite Dataset Mnist") ### IGFS The simplest way to try IGFS with TensorFlow is to run [Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS -and then interruct with it using TensorFlow +and then interrupt with it using TensorFlow [tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container is available on Docker Hub: [dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 66e654ca636a5a051c6f9cd35bf9001dfbcbf7f4..3ffceef8070e0fc3b3cebae2522f89fe98ce4413 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -735,8 +735,6 @@ class IgniteDataset(dataset_ops.DatasetSource): cert_password: Password to be used if the private key is encrypted and a password is necessary. """ - super(IgniteDataset, self).__init__() - with IgniteClient(host, port, username, password, certfile, keyfile, cert_password) as client: client.handshake() @@ -760,6 +758,8 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), self.cache_type.to_output_classes()) + super(IgniteDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, self.local, self.part, self.page_size, diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ff5d4c458c859fd8e5e3ae65ee41a454d55d6538..89b74fbfdc38c9f42795d5c778889210baf6387f 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -19,9 +19,9 @@ from __future__ import print_function import os +from tensorflow import compat from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -66,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset_ops.make_one_shot_iterator(dataset) + it = compat.v1.data.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh old mode 100644 new mode 100755 diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index b399e1b6c2ac47db205b5d8bbc81875ef5c08a31..5591c3b0cc8c8bf196bb4821c018cbf155cba4ce 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -52,7 +52,6 @@ class KafkaDataset(dataset_ops.DatasetSource): timeout: The timeout value for the Kafka Consumer to wait (in millisecond). """ - super(KafkaDataset, self).__init__() self._topics = ops.convert_to_tensor( topics, dtype=dtypes.string, name="topics") self._servers = ops.convert_to_tensor( @@ -63,6 +62,8 @@ class KafkaDataset(dataset_ops.DatasetSource): self._timeout = ops.convert_to_tensor( timeout, dtype=dtypes.int64, name="timeout") + super(KafkaDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.kafka_dataset(self._topics, self._servers, self._group, self._eof, self._timeout) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 4ef0a66a52429233c6e6f70667a451466493629c..294a7d69a704b3c06ab9e30489af116929ab6c2a 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -34,7 +34,7 @@ def sparse_multiclass_hinge_loss( scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS): - """Adds Ops for computing the multiclass hinge loss. + r"""Adds Ops for computing the multiclass hinge loss. The implementation is based on the following paper: On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 42b91d031375b8edb7e4f364ac91ffb74ef1f54b..19daffea6c7e4486499388314d0aaaa611e94218 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,3 +1,3 @@ # K-FAC: Kronecker-Factored Approximate Curvature -## KFAC moved to third_party/tensorflow_kfac. +## KFAC moved to https://github.com/tensorflow/kfac. diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 2b1d478a9b0fd12ca25c72da6872acccfd7285fc..9479afb180df7bb4a08d6aafa4fc3bf63489d9f3 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -71,7 +71,6 @@ class KinesisDataset(dataset_ops.DatasetSource): interval: The interval for the Kinesis Client to wait before it tries to get records again (in millisecond). """ - super(KinesisDataset, self).__init__() self._stream = ops.convert_to_tensor( stream, dtype=dtypes.string, name="stream") self._shard = ops.convert_to_tensor( @@ -80,6 +79,7 @@ class KinesisDataset(dataset_ops.DatasetSource): read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") self._interval = ops.convert_to_tensor( interval, dtype=dtypes.int64, name="interval") + super(KinesisDataset, self).__init__(self._as_variant_tensor()) def _as_variant_tensor(self): return gen_dataset_ops.kinesis_dataset( diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 9ca6f8df5dbe3c236c4cd85095176ce69ad9deaa..69d5496f8aebb9b89c5d79f80a1a439f556093d7 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -81,6 +81,7 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb..00e41026d0038409ace178e6affd2c1cdc812122 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -1757,7 +1757,7 @@ class WeightedSumTest(test.TestCase): logits_core = fc_core.linear_model(features, [movies]) with self.cached_session() as sess: - variables_lib.initialize_all_variables().run() + variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index d791418c9d0f887058ceb535092fa8122da1aa75..1c0088186c030437454c0f764decab9e5a276adc 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1356,7 +1356,7 @@ class DropoutTest(test.TestCase): with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul_1') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index 11033a2e9cb646c2e7cd2f45de1f751d88c6921a..76b03ff514821d3459f84c5f46a64d1134e0d4de 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -186,7 +186,7 @@ def group_norm(inputs, Args: inputs: A Tensor with at least 2 dimensions one which is channels. All - shape dimensions must be fully defined. + shape dimensions except for batch must be fully defined. groups: Integer. Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. @@ -249,13 +249,21 @@ def group_norm(inputs, """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) - original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') + # Use dynamic shape for not fully defined dimensions in the inputs. + dyanmic_shape = array_ops.shape(inputs) + input_shape_list = [] + for i, dim in enumerate(inputs.shape): + if dim.value is None: + input_shape_list.append(dyanmic_shape[i]) + else: + input_shape_list.append(dim) + # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis @@ -289,8 +297,8 @@ def group_norm(inputs, # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] - axes_before_channels = inputs.shape.as_list()[:channels_axis] - axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + axes_before_channels = input_shape_list[:channels_axis] + axes_after_channels = input_shape_list[channels_axis+1:] # Manually broadcast the parameters to conform to the number of groups. params_shape_broadcast = ([1] * len(axes_before_channels) + @@ -369,7 +377,7 @@ def group_norm(inputs, outputs = inputs * gain + offset # Collapse the groups into the channel dimension. - outputs = array_ops.reshape(outputs, original_shape) + outputs = array_ops.reshape(outputs, input_shape_list) if activation_fn is not None: outputs = activation_fn(outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index c8d3c91b10dbe3b959e91182f9924b78352d370d..9a85084b239837ade87d8c778393ef8e885f5bdd 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -221,6 +221,15 @@ class GroupNormTest(test.TestCase): normalization.group_norm(inputs, channels_axis=-1, reduction_axes=[-3, -2]) + def testParamsShapeNotFullyDefinedBatchAxis(self): + height, width, groups = 3, 3, 4 + inputs = array_ops.placeholder(dtypes.float32, + shape=(None, height, width, 2*groups)) + output = normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2], groups=groups) + self.assertListEqual([None, height, width, 2 * groups], + output.shape.as_list()) + def testCreateOp(self): height, width, groups = 3, 3, 4 images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 8a6b4f68a8b33d497ddb16614a7e3cdf32f2c422..5234869718b427d7e275b76ae12021a096241a56 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -399,7 +399,7 @@ def _mean_squared_loss(logits, target): target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) - return math_ops.square(logits - math_ops.to_float(target)) + return math_ops.squared_difference(logits, math_ops.to_float(target)) def _log_loss_with_two_classes(logits, target): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 14065fcee51c014a1af227504eaaca1fa39941e1..4749371248ee89a033912132986d7f76c85dbaa6 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -357,9 +357,9 @@ py_test( py_test( name = "dnn_linear_combined_test", - size = "large", + size = "medium", srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], - shard_count = 4, + shard_count = 8, srcs_version = "PY2AND3", tags = ["no_oss"], # flaky b/70524820 deps = [ diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md index b0bff915a993c9a01e2e6d9ef9f71c14d2f29a73..b2d3a6273abba7e3a893f30bbdd4f8b2662bd54a 100644 --- a/tensorflow/contrib/learn/README.md +++ b/tensorflow/contrib/learn/README.md @@ -111,18 +111,17 @@ Some arguments are renamed, please refer to documentation. In addition: Switch to `tf.estimator.train_and_evaluate`. Some differences: -* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, - should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. -* Remove the `experiment_fn`. Instead, create the `Estimator`, - `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` - directly. -* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement - for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the - replacement for `tf.contrib.learn.make_export_strategy`. If you want to export - only at the end of training use `tf.estimator.FinalExporter`. -* If the `TF_CONFIG` environment variable is constructed manually, please read - the `train_and_evaluate` documentation for the new requirementds (in - particular, the chief node and evaluator node). +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, `train_spec` + and `eval_spec`, then call `tf.estimator.train_and_evaluate` directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement for + `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to + export only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirements (in + particular, the chief node and evaluator node). ## Others Classes and Functions diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c1b97d8b49613ea49d9813954da3b7a63d3ba04c..4bb14a6e63b159fa4d09c9ef20947d4b125de657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -567,7 +567,8 @@ def _mean_squared_loss(labels, logits, weights=None): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - loss = math_ops.square(logits - math_ops.to_float(labels), name=name) + loss = math_ops.squared_difference( + logits, math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index 5e90d1fa20535de3b5e25bc7ff8c3862cea5514c..318046733bf75a6d661d26f478118c8e944afe15 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -174,7 +174,7 @@ class GeneratorIoTest(test.TestCase): return np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must be generator'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() @@ -185,7 +185,7 @@ class GeneratorIoTest(test.TestCase): yield np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must yield dict'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index e7d091e18a8f186f89f5217442c24fb106c5cdab..af93e517f51ed33a8968982945ac1f65ec915ab1 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -36,10 +36,10 @@ def _create_parser(base_dir): # Modify the path object for RegEx match for Windows Paths if os.name == "nt": match = re.match( - "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + r"^" + compat.as_str_any(base_dir).replace("\\", "/") + r"/(\d+)$", compat.as_str_any(path.path).replace("\\", "/")) else: - match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", + match = re.match(r"^" + compat.as_str_any(base_dir) + r"/(\d+)$", compat.as_str_any(path.path)) if not match: return None diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 229a72a780d5ccce8263444ffeae7700f6ac8613..c2916b82a1cefc4615547e77fdd6f4dd48d2a600 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import from tensorflow.python.ops.lookup_ops import FastHashSpec from tensorflow.python.ops.lookup_ops import HasherSpec -from tensorflow.python.ops.lookup_ops import HashTable from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets from tensorflow.python.ops.lookup_ops import index_table_from_file from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file @@ -288,6 +287,83 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): return table.lookup(tensor) +class HashTable(InitializableLookupTableBase): + """A generic hash table implementation. + + Example usage: + + ```python + table = tf.HashTable( + tf.KeyValueTensorInitializer(keys, values), -1) + out = table.lookup(input_tensor) + table.init.run() + print(out.eval()) + ``` + """ + + def __init__(self, initializer, default_value, shared_name=None, name=None): + """Creates a non-initialized `HashTable` object. + + Creates a table, the type of its keys and values are specified by the + initializer. + Before using the table you will have to initialize it. After initialization + the table will be immutable. + + Args: + initializer: The table initializer to use. See `HashTable` kernel for + supported key and value types. + default_value: The value to use if a key is missing in the table. + shared_name: If non-empty, this table will be shared under the given name + across multiple sessions. + name: A name for the operation (optional). + + Returns: + A `HashTable` object. + """ + self._initializer = initializer + self._default_value = default_value + self._shared_name = shared_name + self._name = name or "hash_table" + self._table_name = None + super(HashTable, self).__init__(default_value, initializer) + self._value_shape = self._default_value.get_shape() + + def create_resource(self): + table_ref = gen_lookup_ops.hash_table_v2( + shared_name=self._shared_name, + key_dtype=self._initializer.key_dtype, + value_dtype=self._initializer.value_dtype, + name=self._name) + if context.executing_eagerly(): + self._table_name = None + else: + self._table_name = table_ref.op.name.split("/")[-1] + return table_ref + + @property + def name(self): + return self._table_name + + def export(self, name=None): + """Returns tensors of all keys and values in the table. + + Args: + name: A name for the operation (optional). + + Returns: + A pair of tensors with the first tensor containing all keys and the + second tensors containing all values in the table. + """ + with ops.name_scope(name, "%s_Export" % self.name, + [self.resource_handle]) as name: + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( + self.resource_handle, self._key_dtype, self._value_dtype, name=name) + + exported_values.set_shape(exported_keys.get_shape().concatenate( + self._value_shape)) + return exported_keys, exported_values + + class MutableHashTable(LookupInterface): """A generic mutable hash table implementation. diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 709a042bbcefb89125f7e4cd14a0d7ecd2b53281..5ebdd0b8b50063c99e6b747c594eb99c306b4efb 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -511,7 +511,7 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = math_ops.square(math_ops.subtract(predictions, labels)) + losses = math_ops.squared_difference(predictions, labels) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index de76acb51ffe985162a66c617b266f47c5216b19..f3b0e77740ff1d940fcd6d00b3482e90f6ebf952 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -105,7 +105,8 @@ def contrastive_loss(labels, embeddings_anchor, embeddings_positive, # Get per pair distances distances = math_ops.sqrt( math_ops.reduce_sum( - math_ops.square(embeddings_anchor - embeddings_positive), 1)) + math_ops.squared_difference(embeddings_anchor, embeddings_positive), + 1)) # Add contrastive loss for the siamese network. # label here is {0,1} for neg, pos. diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 87c73ec1ca610cac6d63468887bc350bada5910b..8330c45cc16ffa536107e25699379bb5d9e8993b 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -36,6 +36,7 @@ tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc +tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/util/event.pb.cc tensorflow/core/util/memmapped_file_system.pb.cc tensorflow/core/util/saved_tensor_slice.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 4120ea52ec5255b1efce7a6ce6890fc79c1e4831..7257ac8feedfb8ed18c4d691cd85766e70a48ae8 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -37,6 +37,7 @@ tensorflow/core/protobuf/rewriter_config.pb.h tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h +tensorflow/core/protobuf/verifier_config.pb.h tensorflow/core/util/event.pb.h tensorflow/core/util/memmapped_file_system.pb.h tensorflow/core/util/saved_tensor_slice.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 655c7eefcb978d40c8bc16a23685e03ed71bfb63..2cd7d6d519a55423a96526b541845392d9ec6bc2 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -119,6 +119,7 @@ tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/fifo_queue.cc tensorflow/core/kernels/fifo_queue_op.cc tensorflow/core/kernels/fill_functor.cc +tensorflow/core/kernels/fft_ops.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/gather_functor.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f94d70db9046cec43073ab1406762aea1f28c8e3..13e3b6422d1989b0d499d8d20901d919554c630e 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -29,5 +29,6 @@ tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc +tensorflow/core/protobuf/verifier_config.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2712e906d719e72dacb60e213205ad68895f905f..24d86d313b76343ed9450a33cf185d9c426696bb 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -43,6 +43,7 @@ tensorflow/core/protobuf/rewriter_config.proto tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/protobuf/verifier_config.proto tensorflow/core/util/event.proto tensorflow/core/util/memmapped_file_system.proto tensorflow/core/util/saved_tensor_slice.proto diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 7b432f8bd20989c6d95310bcaca88d44ce3e0d1f..ece246b7c28569a551f7733daf16ee1507f9c95d 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1356,9 +1356,8 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / (total_0 - 1. + _EPSILON)) var_1 = ( - math_ops.reduce_sum( - weights_1 * math_ops.square(placement_values_1 - auc_1)) / - (total_1 - 1. + _EPSILON)) + math_ops.reduce_sum(weights_1 * math_ops.squared_difference( + placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON)) auc_std_err = math_ops.sqrt( (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index ecac06354d2ce796f2a6021cdf2370d7c30ccab7..a7be92a35e0d62a61f7923ac61bb2c1267d039c6 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,7 +52,6 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 0446e823d95f8ecbed6a0c34a83ade009e68448b..12320d9e456ae93cbf95639a0c9e0c7f414f3518 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -319,6 +319,9 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//third_party/py/numpy", ], + tags = [ + "oss_serial", + ], ) tf_py_test( diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py index 3fb649ea82e79b3bc78a2da6d5c3e9a071adec6d..0b149ed17533adff3bd7cd8fd8ff94d171f72911 100644 --- a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" from __future__ import absolute_import from __future__ import division @@ -38,10 +37,15 @@ class AdamGSOptimizer(optimizer.Optimizer): ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). """ - def __init__(self, global_step=0, learning_rate=0.001, - beta1=0.9, beta2=0.999, epsilon=1e-8, - use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + def __init__(self, + global_step=0, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + name="Adam"): + r"""Construct a new Adam optimizer. Branched from tf.train.AdamOptimizer. The only difference is to pass global step for computing beta1 and beta2 accumulators, instead of having @@ -83,23 +87,20 @@ class AdamGSOptimizer(optimizer.Optimizer): Args: global_step: tensorflow variable indicating the step. learning_rate: A Tensor or a floating point value. The learning rate. - beta1: A float value or a constant float tensor. - The exponential decay rate for the 1st moment estimates. - beta2: A float value or a constant float tensor. - The exponential decay rate for the 2nd moment estimates. + beta1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. This epsilon is "epsilon hat" in the Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper. use_locking: If True use locks for update operations. name: Optional name for the operations created when applying gradients. - Defaults to "Adam". - - @compatibility(eager) - When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and - `epsilon` can each be a callable that takes no arguments and returns the - actual value to use. This can be useful for changing these values across - different invocations of optimizer functions. - @end_compatibility + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be a + callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility """ super(AdamGSOptimizer, self).__init__(use_locking, name) self._lr = learning_rate @@ -115,9 +116,6 @@ class AdamGSOptimizer(optimizer.Optimizer): self._beta2_t = None self._epsilon_t = None - # Created in SparseApply if needed. - self._updated_lr = None - def _get_beta_accumulators(self): return (math_ops.pow(self._beta1_t, self._global_step_on_worker), math_ops.pow(self._beta2_t, self._global_step_on_worker)) @@ -149,28 +147,34 @@ class AdamGSOptimizer(optimizer.Optimizer): v = self.get_slot(var, "v") beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.apply_adam( - var, m, v, + var, + m, + v, math_ops.cast(beta1_power, var.dtype.base_dtype), math_ops.cast(beta2_power, var.dtype.base_dtype), math_ops.cast(self._lr_t, var.dtype.base_dtype), math_ops.cast(self._beta1_t, var.dtype.base_dtype), math_ops.cast(self._beta2_t, var.dtype.base_dtype), math_ops.cast(self._epsilon_t, var.dtype.base_dtype), - grad, use_locking=self._use_locking).op + grad, + use_locking=self._use_locking).op def _resource_apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") beta1_power, beta2_power = self._get_beta_accumulators() return training_ops.resource_apply_adam( - var.handle, m.handle, v.handle, + var.handle, + m.handle, + v.handle, math_ops.cast(beta1_power, grad.dtype.base_dtype), math_ops.cast(beta2_power, grad.dtype.base_dtype), math_ops.cast(self._lr_t, grad.dtype.base_dtype), math_ops.cast(self._beta1_t, grad.dtype.base_dtype), math_ops.cast(self._beta2_t, grad.dtype.base_dtype), math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), - grad, use_locking=self._use_locking) + grad, + use_locking=self._use_locking) def _apply_sparse_shared(self, grad, var, indices, scatter_add): beta1_power, beta2_power = self._get_beta_accumulators() @@ -184,8 +188,7 @@ class AdamGSOptimizer(optimizer.Optimizer): # m_t = beta1 * m + (1 - beta1) * g_t m = self.get_slot(var, "m") m_scaled_g_values = grad * (1 - beta1_t) - m_t = state_ops.assign(m, m * beta1_t, - use_locking=self._use_locking) + m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) with ops.control_dependencies([m_t]): m_t = scatter_add(m, indices, m_scaled_g_values) # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) @@ -195,23 +198,26 @@ class AdamGSOptimizer(optimizer.Optimizer): with ops.control_dependencies([v_t]): v_t = scatter_add(v, indices, v_scaled_g_values) v_sqrt = math_ops.sqrt(v_t) - var_update = state_ops.assign_sub(var, - lr * m_t / (v_sqrt + epsilon_t), - use_locking=self._use_locking) + var_update = state_ops.assign_sub( + var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) return control_flow_ops.group(*[var_update, m_t, v_t]) def _apply_sparse(self, grad, var): return self._apply_sparse_shared( - grad.values, var, grad.indices, + grad.values, + var, + grad.indices, lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda - x, i, v, use_locking=self._use_locking)) + x, + i, + v, + use_locking=self._use_locking)) def _resource_scatter_add(self, x, i, v): with ops.control_dependencies( - [resource_variable_ops.resource_scatter_add( - x.handle, i, v)]): + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): return x.value() def _resource_apply_sparse(self, grad, var, indices): - return self._apply_sparse_shared( - grad, var, indices, self._resource_scatter_add) + return self._apply_sparse_shared(grad, var, indices, + self._resource_scatter_add) diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index 248ffb1f7eb5dc27112ddf9b8670344904065ed0..1b7800f324b908e3c88fe90d31a2a08cbbd5ccf2 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -36,7 +36,7 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + r"""Construct a new Adam optimizer. Initialization: diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 72019b31540a943582ebb4699013d9dcfc10769f..0243927ce44aec626973744507e75b20a42253e9 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -48,7 +48,7 @@ from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util -class NonLayerCheckpointable(tracking.Checkpointable): +class NonLayerCheckpointable(tracking.AutoCheckpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -440,7 +440,7 @@ class CheckpointingTests(test.TestCase): def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = tracking.Checkpointable() + root = tracking.AutoCheckpointable() root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -463,7 +463,7 @@ class CheckpointingTests(test.TestCase): 14.)) slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = tracking.Checkpointable() + new_root = tracking.AutoCheckpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). slot_status = util.CheckpointableSaver( @@ -508,7 +508,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = tracking.AutoCheckpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -526,7 +526,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = tracking.AutoCheckpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 7fb23abc38d9dc101204ed83808aebe5a8ef1e78..1323ed014c9e51e273491694fa44a8e36cc723d0 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -843,8 +843,7 @@ class OptimizerV2(optimizer_v1.Optimizer): scale_loss_by_num_replicas = ( distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: - num_replicas = \ - distribute_ctx.get_distribution_strategy().num_replicas_in_sync + num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: loss_value *= 1. / num_replicas return loss_value diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b35c4fde1a2c704880e023a0c3ac1e0766493514..b67e68ea96a15f94e62050c92405eec4fe4be70f 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -202,8 +202,9 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "large", + size = "medium", srcs = ["python/quantize_parameterized_test.py"], + shard_count = 4, srcs_version = "PY2AND3", # TODO(b/118839526): Re-enable msan test. tags = [ diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 9085d9fa719520ac84ef6f8e07d7fa335bef5605..5b8da92491fb747c5a37dcfe03bcb21b5b903560 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -110,7 +110,7 @@ See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/). ## Quantized accuracy results -The following are results of trainiing some popular CNN models (Mobilenet-v1, +The following are results of training some popular CNN models (Mobilenet-v1, Mobilenet-v2, and Inception-v3) using this tool:

diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index 0e3c46f17d2e2a277418d39e31927db73a509670..92ae1021bc8f8fbf19ca7f7cbe208ecea18128e8 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,8 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN", + "GreaterEqual" ] # Different ways in which padding modes may be spelled. @@ -276,11 +277,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op == "Pad": # Kernel and stride are simply 1 in this case. kernel_size_x = 1 @@ -294,11 +295,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _pool_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op in _UNCHANGED_RF_LAYER_OPS: # These nodes do not modify the RF parameters. kernel_size_x = 1 @@ -320,7 +321,7 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): total_padding_y = None padding_y = None else: - raise ValueError("Unknown layer for operation '%s': %s" % (node.name, - node.op)) + raise ValueError( + "Unknown layer for operation '%s': %s" % (node.name, node.op)) return (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, total_padding_x, total_padding_y) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 44b232e0f2b26f16f0300e11cf2764e1157a0050..d65d80df8073ef70d591c4ae2af99132f1c318ef 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -227,7 +227,10 @@ tf_custom_op_library( "kernels/lstm_ops_gpu.cu.cc", "kernels/lstm_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -249,7 +252,10 @@ tf_custom_op_library( "kernels/gru_ops_gpu.cu.cc", "kernels/gru_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -346,6 +352,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", ], @@ -381,6 +388,13 @@ py_binary( name = "checkpoint_convert", srcs = ["python/tools/checkpoint_convert.py"], srcs_version = "PY2AND3", + deps = [":checkpoint_convert_lib"], +) + +py_library( + name = "checkpoint_convert_lib", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_ops", @@ -399,7 +413,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":checkpoint_convert", + ":checkpoint_convert_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:session", diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index d37210d4b81203287fb633adc309688a35d093bb..12f3182a6a8878aa27ee143fa6405903e3fc4ef3 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -21,6 +21,10 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + namespace tensorflow { class OpKernelContext; namespace functor { diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index a0d013c618ea56077098b15b7eed5f9110239516..7bad4a60a149011d5b8d745f45359fd25473e54e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -210,6 +210,35 @@ class RNNCellTest(test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.509682, 0.509682]]) + def testSRUCellKerasRNN(self): + """Tests that SRUCell works with keras RNN layer.""" + cell = contrib_rnn_cell.SRUCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs_keras = rnn_layer(seq_input) + with self.cached_session() as sess: + sess.run([variables_lib.global_variables_initializer()]) + self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) + + def testSRUCellBiasType(self): + """Tests that the bias' dtype is properly set.""" + cell = contrib_rnn_cell.SRUCell(10) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.float32_ref) + + cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.int32_ref) + + cell_input = ops.convert_to_tensor( + np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) + cell_state = ops.convert_to_tensor( + np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) + cell = contrib_rnn_cell.SRUCell(10) + cell(cell_input, [cell_state]) + self.assertEqual(cell._bias.dtype, dtypes.float16_ref) + def testSRUCellWithDiffSize(self): with self.cached_session() as sess: with variable_scope.variable_scope( diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index aa1d7d2b01b4595bbb03ba8e867e93db759cbd52..d7ee7fb8faacb0876218a983d68f007e1905c11e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -29,7 +29,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import initializers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras import testing_utils from tensorflow.python.keras import utils from tensorflow.python.ops import array_ops @@ -763,6 +765,17 @@ class RNNCellTest(test.TestCase): self.assertEqual(new_h.shape[1], num_proj) self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) + @test_util.run_in_graph_and_eager_modes + def testNASCellKerasRNN(self): + """Tests that NASCell works with keras RNN layer.""" + cell = contrib_rnn_cell.NASCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs = rnn_layer(seq_input) + self.evaluate([variables.global_variables_initializer()]) + self.assertEqual(self.evaluate(rnn_outputs).shape, (2, 10)) + def testUGRNNCell(self): num_units = 2 batch_size = 3 diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 8a1c09f171e6108174671e3122d5ff4c0b236003..482e547a16be85804beec88a91fa03b053d09b27 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1462,7 +1462,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): return new_h, new_state -class NASCell(rnn_cell_impl.RNNCell): +class NASCell(rnn_cell_impl.LayerRNNCell): """Neural Architecture Search (NAS) recurrent network cell. This implements the recurrent cell from the paper: @@ -1475,23 +1475,28 @@ class NASCell(rnn_cell_impl.RNNCell): The class uses an optional projection layer. """ - def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None): + # NAS cell's architecture base. + _NAS_BASE = 8 + + def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None, + **kwargs): """Initialize the parameters for a NAS cell. Args: - num_units: int, The number of units in the NAS cell + num_units: int, The number of units in the NAS cell. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. - use_biases: (optional) bool, If True then use biases within the cell. This + use_bias: (optional) bool, If True then use biases within the cell. This is False by default. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. + **kwargs: Additional keyword arguments. """ - super(NASCell, self).__init__(_reuse=reuse) + super(NASCell, self).__init__(_reuse=reuse, **kwargs) self._num_units = num_units self._num_proj = num_proj - self._use_biases = use_biases + self._use_bias = use_bias self._reuse = reuse if num_proj is not None: @@ -1509,6 +1514,33 @@ class NASCell(rnn_cell_impl.RNNCell): def output_size(self): return self._output_size + def build(self, inputs_shape): + input_size = tensor_shape.dimension_value( + tensor_shape.TensorShape(inputs_shape).with_rank(2)[1]) + if input_size is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + num_proj = self._num_units if self._num_proj is None else self._num_proj + + # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying + # the hiddenstate and `kernel` is all matrices multiplying the inputs. + self.recurrent_kernel = self.add_variable( + "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units]) + self.kernel = self.add_variable( + "kernel", [input_size, self._NAS_BASE * self._num_units]) + + if self._use_bias: + self.bias = self.add_variable("bias", + shape=[self._NAS_BASE * self._num_units], + initializer=init_ops.zeros_initializer) + + # Projection layer if specified + if self._num_proj is not None: + self.projection_weights = self.add_variable( + "projection_weights", [self._num_units, self._num_proj]) + + self.built = True + def call(self, inputs, state): """Run one step of NAS Cell. @@ -1535,38 +1567,20 @@ class NASCell(rnn_cell_impl.RNNCell): tanh = math_ops.tanh relu = nn_ops.relu - num_proj = self._num_units if self._num_proj is None else self._num_proj - (c_prev, m_prev) = state - dtype = inputs.dtype - input_size = inputs.get_shape().with_rank(2).dims[1] - if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - # Variables for the NAS cell. W_m is all matrices multiplying the - # hiddenstate and W_inputs is all matrices multiplying the inputs. - concat_w_m = vs.get_variable("recurrent_kernel", - [num_proj, 8 * self._num_units], dtype) - concat_w_inputs = vs.get_variable( - "kernel", [input_size.value, 8 * self._num_units], dtype) - - m_matrix = math_ops.matmul(m_prev, concat_w_m) - inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) - - if self._use_biases: - b = vs.get_variable( - "bias", - shape=[8 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) - m_matrix = nn_ops.bias_add(m_matrix, b) + m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel) + inputs_matrix = math_ops.matmul(inputs, self.kernel) + + if self._use_bias: + m_matrix = nn_ops.bias_add(m_matrix, self.bias) # The NAS cell branches into 8 different splits for both the hiddenstate # and the input m_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=m_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix) inputs_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=inputs_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix) # First layer layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) @@ -1598,9 +1612,7 @@ class NASCell(rnn_cell_impl.RNNCell): # Projection layer if specified if self._num_proj is not None: - concat_w_proj = vs.get_variable("projection_weights", - [self._num_units, self._num_proj], dtype) - new_m = math_ops.matmul(new_m, concat_w_proj) + new_m = math_ops.matmul(new_m, self.projection_weights) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) return new_m, new_state @@ -2071,7 +2083,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): conv_ndims: Convolution dimensionality (1, 2 or 3). input_shape: Shape of the input as int tuple, excluding the batch size. output_channels: int, number of output channels of the conv LSTM. - kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). + kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3). use_bias: (bool) Use bias in convolutions. skip_connection: If set to `True`, concatenate the input to the output of the conv LSTM. Default: `False`. @@ -2092,7 +2104,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): self._conv_ndims = conv_ndims self._input_shape = input_shape self._output_channels = output_channels - self._kernel_shape = kernel_shape + self._kernel_shape = list(kernel_shape) self._use_bias = use_bias self._forget_bias = forget_bias self._skip_connection = skip_connection @@ -2172,7 +2184,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0): Args: args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, batch x n, Tensors. - filter_size: int tuple of filter height and width. + filter_size: int tuple of filter shape (of size 1, 2 or 3). num_features: int, number of features. bias: Whether to use biases in the convolution layer. bias_start: starting value to initialize the bias; 0 by default. @@ -2744,10 +2756,12 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): name: (optional) String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. + **kwargs: Additional keyword arguments. """ - def __init__(self, num_units, activation=None, reuse=None, name=None): - super(SRUCell, self).__init__(_reuse=reuse, name=name) + def __init__(self, num_units, activation=None, reuse=None, name=None, + **kwargs): + super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -2777,7 +2791,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._bias = self.add_variable( rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access shape=[2 * self._num_units], - initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + initializer=init_ops.zeros_initializer) self._built = True diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index 3fc6bfbb4d03a39906d4441e48b2788423caa234..d8ab9eba7049e468b373a1641f92dc781aa22558 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -61,10 +61,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): self._server = server def tearDown(self): - # TODO(ebrevdo): Figure out why this sometimes times out. - # self._service.ExitLoop() - # self._service_thread.join() - # self._server.stop() + self._server.stop(grace=None) super(RpcOpTest, self).tearDown() diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 0d615923e04915a8429252317025ac8e79f9bb4e..d6148715be91c78e6e5a99fc0f3caa905b5c1a7d 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -176,7 +176,9 @@ class RpcOpTestBase(object): expected_message_values = np.where( status_code_values == errors.INVALID_ARGUMENT, I_WARNED_YOU.encode('ascii'), b'') - self.assertAllEqual(expected_message_values, status_message_values) + for msg, expected in zip(status_message_values, expected_message_values): + self.assertTrue(expected in msg, + '"%s" did not contain "%s"' % (msg, expected)) def testVecHostPortRpc(self): with self.cached_session() as sess: diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 269443b2c6508bb618d30f64487b1a6a84e8646f..f0242a3b40fd566ec0f477d462426d5f550d1620 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -84,35 +84,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:saver", - "//tensorflow/python:util", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras:engine", - "//tensorflow/python/saved_model", - ], -) - -py_test( - name = "keras_saved_model_test", - size = "medium", - srcs = ["python/saved_model/keras_saved_model_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", # TODO(b/119349471): Re-enable - "no_windows", - ], - deps = [ - ":keras_saved_model", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 2a4b6eae367fe617e9a19d80f16eb3fda9ade1c0..0392ed9eee79391c60318faf68d8dfd6eb64a994 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -18,398 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import six +from tensorflow.python.keras import saving -from tensorflow.python.client import session -from tensorflow.python.framework import ops -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import models as models_lib -from tensorflow.python.keras import optimizers -from tensorflow.python.keras.engine import sequential -from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.metrics import Metric -from tensorflow.python.keras.models import model_from_json -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import save as save_lib -from tensorflow.python.saved_model import utils_impl as saved_model_utils -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils -from tensorflow.python.util import compat -from tensorflow.python.util import nest -from tensorflow_estimator.python.estimator import keras as estimator_keras_util -from tensorflow_estimator.python.estimator import model_fn as model_fn_lib -from tensorflow_estimator.python.estimator.export import export as export_helpers - -def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None, - input_signature=None, serving_only=False): - """Saves a `tf.keras.Model` into Tensorflow SavedModel format. - - `save_model` generates new files/folders under the `saved_model_path` folder: - 1) a checkpoint containing the model weights. - 2) a saved_model.pb file containing the model's MetaGraphs. The prediction - graph is always exported. The evaluaton and training graphs are exported - if the following conditions are met: - - Evaluation: model loss is defined. - - Training: model is compiled with an optimizer defined under `tf.train`. - This is because `tf.keras.optimizers.Optimizer` instances cannot be - saved to checkpoints. - 3) Model's json configuration, if model.get_config() has been implemented. - This file can be used to reload the model using - tf.keras.models.model_from_json(). Note that if any custom objects were - used, they should be passed to the `custom_object` argument when loading - the model. - - Model limitations: - - Sequential and functional models can always be saved. - - Subclassed models can only be saved when `serving_only=True`. This is due to - the current implementation copying the model in order to export the training - and evaluation graphs. Because the topology of subclassed models cannot be - determined, the subclassed models cannot be cloned. Subclassed models will - be entirely exportable in the future. - - Note that each mode is exported in separate graphs, so different modes do not - share variables. To use the train graph with evaluation or prediction graphs, - create a new checkpoint if variable values have been updated. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. - model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag - `serving_only` must be set to True. - saved_model_path: a string specifying the path to the SavedModel directory. - The SavedModel will be saved to a timestamped folder created within this - directory. - custom_objects: Optional dictionary mapping string names to custom classes - or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. Currently - unavailable in serving-only mode. - input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used - to specify the expected model inputs. `input_signature`'s nested structure - should match the expected nested structure of the inputs to the model. If - this is not set, this function will attempt to infer the input shapes and - dtypes from the model. Note that if the model is subclassed, the tensor - inputs to the call function should be nested in the first argument (this - is a general requirement for using subclassed models with Keras functions - .fit(), .predict(), etc.). - serving_only: Export only the outputs produced from calling the model in - predict mode. The losses, optimizer, and other training configurations are - not saved. If the SavedModel will only be used for serving (rather than - retraining), or if the model is subclassed, this can be set to True. - - Returns: - String path to the SavedModel folder, a subdirectory of `saved_model_path`. - - Raises: - NotImplementedError: If the model is a subclassed model, and serving_only is - False. - ValueError: If the input signature cannot be inferred from the model. - """ - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - - if serving_only: - save_lib.save( - model, export_dir, - signatures=training_utils.trace_model_call(model, input_signature)) - else: - _save_v1_format(model, export_dir, custom_objects, as_text, input_signature) - - try: - _export_model_json(model, export_dir) - except NotImplementedError: - logging.warning('Skipped saving model JSON, subclassed model does not have ' - 'get_config() defined.') - - return export_dir - - -def _export_model_json(model, saved_model_path): - """Saves model configuration as a json string under assets folder.""" - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - -def _export_model_variables(model, saved_model_path): - """Saves model weights in checkpoint format under variables folder.""" - saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - - -def _save_v1_format(model, path, custom_objects, as_text, input_signature): - """Exports model to v1 SavedModel format.""" - if not model._is_graph_network: - if isinstance(model, sequential.Sequential): - # If input shape is not directly set in the model, the exported model - # will infer the expected shapes of the input from the model. - if not model.built and input_signature is None: - raise ValueError( - 'Sequential model\'s input shape is unknown. Please build the ' - 'model, or use the input_signature argument to specify the ' - 'model inputs.') - else: - raise NotImplementedError( - 'Subclassed models can only be exported for serving. Please set ' - 'argument serving_only=True.') - - builder = saved_model_builder._SavedModelBuilder(path) - - # Manually save variables to export them in an object-based checkpoint. This - # skips the `builder.add_meta_graph_and_variables()` step, which saves a - # named-based checkpoint. - # TODO(b/113134168): Add fn to Builder to save with object-based saver. - # TODO(b/113178242): This should only export the model json structure. Only - # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_variables(model, path) - - # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that - # Keras models and `Estimator`s are exported with the same format. - # Every time a mode is exported, the code checks to see if new variables have - # been created (e.g. optimizer slot variables). If that is the case, the - # checkpoint is re-saved to include the new variables. - export_args = {'builder': builder, - 'model': model, - 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path, - 'input_signature': input_signature} - - has_saved_vars = False - if model.optimizer: - # TODO(kathywu): Verify this works with v2 optimizer. - if isinstance(model.optimizer, optimizers.TFOptimizer): - _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) - has_saved_vars = True - _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args) - else: - logging.warning( - 'Model was compiled with an optimizer, but the optimizer is not from ' - '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving ' - 'graph was exported. The train and evaluate graphs were not added to ' - 'the SavedModel.') - _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args) - - builder.save(as_text) - - -def _get_var_list(model): - """Returns list of all checkpointed saveable objects in the model.""" - return checkpointable_utils.named_saveables(model) - - -def create_placeholder(spec): - return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) - - -def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, - input_signature): - """Exports a model, and optionally saves new vars from the clone model. - - Args: - mode: A `tf.estimator.ModeKeys` string. - has_saved_vars: A `boolean` indicating whether the SavedModel has already - exported variables. - builder: A `SavedModelBuilder` object. - model: A `tf.keras.Model` object. - custom_objects: A dictionary mapping string names to custom classes - or functions. - checkpoint_path: String path to checkpoint. - input_signature: Nested TensorSpec containing the expected inputs. Can be - `None`, in which case the signature will be inferred from the model. - - Raises: - ValueError: If the train/eval mode is being exported, but the model does - not have an optimizer. - """ - compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT) - if compile_clone and not model.optimizer: - raise ValueError( - 'Model does not have an optimizer. Cannot export mode %s' % mode) - - model_graph = ops.get_default_graph() - with ops.Graph().as_default() as g: - - K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) - - if input_signature is None: - input_tensors = None - else: - input_tensors = nest.map_structure(create_placeholder, input_signature) - - # Clone the model into blank graph. This will create placeholders for inputs - # and targets. - clone = models_lib.clone_and_build_model( - model, input_tensors=input_tensors, custom_objects=custom_objects, - compile_clone=compile_clone) - - # Make sure that iterations variable is added to the global step collection, - # to ensure that, when the SavedModel graph is loaded, the iterations - # variable is returned by `tf.train.get_global_step()`. This is required for - # compatibility with the SavedModelEstimator. - if compile_clone: - g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) - - # Extract update and train ops from train/test/predict functions. - train_op = None - if mode == model_fn_lib.ModeKeys.TRAIN: - clone._make_train_function() - train_op = clone.train_function.updates_op - elif mode == model_fn_lib.ModeKeys.EVAL: - clone._make_test_function() - else: - clone._make_predict_function() - g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates) - - clone_var_list = checkpointable_utils.named_saveables(clone) - - with session.Session().as_default(): - if has_saved_vars: - # Confirm all variables in the clone have an entry in the checkpoint. - status = clone.load_weights(checkpoint_path) - status.assert_existing_objects_matched() - else: - # Confirm that variables between the clone and model match up exactly, - # not counting optimizer objects. Optimizer objects are ignored because - # if the model has not trained, the slot variables will not have been - # created yet. - # TODO(b/113179535): Replace with checkpointable equivalence. - _assert_same_non_optimizer_objects(model, model_graph, clone, g) - - # TODO(b/113178242): Use value transfer for checkpointable objects. - clone.load_weights(checkpoint_path) - - # Add graph and variables to SavedModel. - # TODO(b/113134168): Switch to add_meta_graph_and_variables. - clone.save_weights(checkpoint_path, save_format='tf', overwrite=True) - builder._has_saved_variables = True - - # Add graph to the SavedModel builder. - builder.add_meta_graph( - model_fn_lib.EXPORT_TAG_MAP[mode], - signature_def_map=_create_signature_def_map(clone, mode), - saver=saver_lib.Saver(clone_var_list), - init_op=variables.local_variables_initializer(), - train_op=train_op) - return None - - -def _create_signature_def_map(model, mode): - """Creates a SignatureDef map from a Keras model.""" - inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} - if model.optimizer: - targets_dict = {x.name.split(':')[0]: x - for x in model.targets if x is not None} - inputs_dict.update(targets_dict) - outputs_dict = {name: x - for name, x in zip(model.output_names, model.outputs)} - metrics = estimator_keras_util._convert_keras_metrics_to_estimator(model) - - # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables - # are by default not added to any collections. We are doing this here, so - # that metric variables get initialized. - local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) - vars_to_add = set() - if metrics is not None: - for key, value in six.iteritems(metrics): - if isinstance(value, Metric): - vars_to_add.update(value.variables) - # Convert Metric instances to (value_tensor, update_op) tuple. - metrics[key] = (value.result(), value.updates[0]) - # Remove variables that are in the local variables collection already. - vars_to_add = vars_to_add.difference(local_vars) - for v in vars_to_add: - ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) - - export_outputs = model_fn_lib.export_outputs_for_mode( - mode, - predictions=outputs_dict, - loss=model.total_loss if model.optimizer else None, - metrics=metrics) - return export_helpers.build_all_signature_defs( - inputs_dict, - export_outputs=export_outputs, - serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) - - -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Asserts model and clone contain the same checkpointable objects.""" - - # TODO(fchollet, kathywu): make sure this works in eager mode. - return True - - -def load_keras_model(saved_model_path): - """Loads a keras.Model from SavedModel. - - load_model reinstantiates model state by: - 1) loading model topology from json (this will eventually come - from metagraph). - 2) loading model weights from checkpoint. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. - model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - saved_model_path: a string specifying the path to an existing SavedModel. - - Returns: - a keras.Model instance. - """ - # restore model topology from json string - model_json_filepath = os.path.join( - compat.as_bytes(saved_model_path), - compat.as_bytes(constants.ASSETS_DIRECTORY), - compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) - model_json = file_io.read_file_to_string(model_json_filepath) - model = model_from_json(model_json) - - # restore model weights - checkpoint_prefix = os.path.join( - compat.as_text(saved_model_path), - compat.as_text(constants.VARIABLES_DIRECTORY), - compat.as_text(constants.VARIABLES_FILENAME)) - model.load_weights(checkpoint_prefix) - return model +# TODO(kathywu): Remove all contrib callers, switch to tf.keras. +save_keras_model = saving.export +load_keras_model = saving.load_from_saved_model diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py deleted file mode 100644 index fbf8138493362d4a3c8a75e1ee1bb2fbe8096499..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ /dev/null @@ -1,538 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# pylint: disable=protected-access -"""Tests for saving/loading function for keras Model.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model -from tensorflow.python import keras -from tensorflow.python.client import session -from tensorflow.python.eager import context -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import training -from tensorflow.python.keras.utils import tf_utils -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.training import training as training_module - - -class TestModelSavingandLoading(test.TestCase): - - def _save_model_dir(self, dirname='saved_model'): - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) - return os.path.join(temp_dir, dirname) - - def test_saving_sequential_model(self): - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.RepeatVector(3)) - model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) - model.compile( - loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy], - sample_weight_mode='temporal') - x = np.random.random((1, 3)) - y = np.random.random((1, 3, 3)) - model.train_on_batch(x, y) - - ref_y = model.predict(x) - - temp_saved_model = self._save_model_dir() - output_path = keras_saved_model.save_keras_model(model, temp_saved_model) - - loaded_model = keras_saved_model.load_keras_model(output_path) - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - @test_util.run_in_graph_and_eager_modes - def test_saving_sequential_model_without_compile(self): - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.RepeatVector(3)) - model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) - - x = np.random.random((1, 3)) - ref_y = model.predict(x) - - temp_saved_model = self._save_model_dir() - output_path = keras_saved_model.save_keras_model(model, temp_saved_model) - loaded_model = keras_saved_model.load_keras_model(output_path) - - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - def test_saving_functional_model(self): - with self.cached_session(): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - output = keras.layers.Dense(3)(x) - - model = keras.models.Model(inputs, output) - model.compile( - loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy]) - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - - ref_y = model.predict(x) - - temp_saved_model = self._save_model_dir() - output_path = keras_saved_model.save_keras_model(model, temp_saved_model) - loaded_model = keras_saved_model.load_keras_model(output_path) - - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - @test_util.run_in_graph_and_eager_modes - def test_saving_functional_model_without_compile(self): - with self.cached_session(): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - output = keras.layers.Dense(3)(x) - - model = keras.models.Model(inputs, output) - - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - - ref_y = model.predict(x) - - temp_saved_model = self._save_model_dir() - output_path = keras_saved_model.save_keras_model(model, temp_saved_model) - loaded_model = keras_saved_model.load_keras_model(output_path) - - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - @test_util.run_in_graph_and_eager_modes - def test_saving_with_tf_optimizer(self): - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile( - loss='mse', - optimizer=training_module.RMSPropOptimizer(0.1), - metrics=['acc']) - - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - ref_y = model.predict(x) - - temp_saved_model = self._save_model_dir() - output_path = keras_saved_model.save_keras_model(model, temp_saved_model) - loaded_model = keras_saved_model.load_keras_model(output_path) - loaded_model.compile( - loss='mse', - optimizer=training_module.RMSPropOptimizer(0.1), - metrics=['acc']) - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - # test that new updates are the same with both models - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - - ref_loss = model.train_on_batch(x, y) - loss = loaded_model.train_on_batch(x, y) - self.assertAllClose(ref_loss, loss, atol=1e-05) - - ref_y = model.predict(x) - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - # test saving/loading again - temp_saved_model2 = self._save_model_dir('saved_model_2') - output_path2 = keras_saved_model.save_keras_model( - loaded_model, temp_saved_model2) - loaded_model = keras_saved_model.load_keras_model(output_path2) - y = loaded_model.predict(x) - self.assertAllClose(ref_y, y, atol=1e-05) - - def test_saving_subclassed_model_raise_error(self): - # For now, saving subclassed model should raise an error. It should be - # avoided later with loading from SavedModel.pb. - - class SubclassedModel(training.Model): - - def __init__(self): - super(SubclassedModel, self).__init__() - self.layer1 = keras.layers.Dense(3) - self.layer2 = keras.layers.Dense(1) - - def call(self, inp): - return self.layer2(self.layer1(inp)) - - model = SubclassedModel() - - temp_saved_model = self._save_model_dir() - with self.assertRaises(NotImplementedError): - keras_saved_model.save_keras_model(model, temp_saved_model) - - -class LayerWithLearningPhase(keras.engine.base_layer.Layer): - - def call(self, x): - phase = keras.backend.learning_phase() - output = tf_utils.smart_cond( - phase, lambda: x * 0, lambda: array_ops.identity(x)) - if not context.executing_eagerly(): - output._uses_learning_phase = True # pylint: disable=protected-access - return output - - def compute_output_shape(self, input_shape): - return input_shape - - -def functional_model(uses_learning_phase=True): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - x = keras.layers.Dense(3)(x) - if uses_learning_phase: - x = LayerWithLearningPhase()(x) - return keras.models.Model(inputs, x) - - -def sequential_model(uses_learning_phase=True): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - if uses_learning_phase: - model.add(LayerWithLearningPhase()) - return model - - -def sequential_model_without_input_shape(uses_learning_phase=True): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2)) - model.add(keras.layers.Dense(3)) - if uses_learning_phase: - model.add(LayerWithLearningPhase()) - return model - - -class Subclassed(keras.models.Model): - - def __init__(self): - super(Subclassed, self).__init__() - self.dense1 = keras.layers.Dense(2) - self.dense2 = keras.layers.Dense(3) - - def call(self, inputs): - x = self.dense1(inputs) - x = self.dense2(x) - return x - - -def subclassed_model(): - return Subclassed() - - -def load_model(sess, path, mode): - tags = model_fn_lib.EXPORT_TAG_MAP[mode] - if mode == model_fn_lib.ModeKeys.PREDICT: - sig_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - else: - sig_def_key = mode - - meta_graph_def = loader_impl.load(sess, tags, path) - inputs = { - k: sess.graph.get_tensor_by_name(v.name) - for k, v in meta_graph_def.signature_def[sig_def_key].inputs.items()} - outputs = { - k: sess.graph.get_tensor_by_name(v.name) - for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()} - return inputs, outputs, meta_graph_def - - -@test_util.run_all_in_graph_and_eager_modes -class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): - - def _save_model_dir(self, dirname='saved_model'): - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) - return os.path.join(temp_dir, dirname) - - @parameterized.parameters( - { - 'model_builder': functional_model, - 'uses_learning_phase': True, - 'optimizer': training_module.AdadeltaOptimizer(), - 'train_before_export': True}, - { - 'model_builder': functional_model, - 'uses_learning_phase': True, - 'optimizer': training_module.AdadeltaOptimizer(), - 'train_before_export': False}, - { - 'model_builder': functional_model, - 'uses_learning_phase': False, - 'optimizer': None, - 'train_before_export': False}, - { - 'model_builder': sequential_model, - 'uses_learning_phase': True, - 'optimizer': training_module.AdadeltaOptimizer(), - 'train_before_export': True}, - { - 'model_builder': sequential_model, - 'uses_learning_phase': True, - 'optimizer': training_module.AdadeltaOptimizer(), - 'train_before_export': False}, - { - 'model_builder': sequential_model, - 'uses_learning_phase': False, - 'optimizer': None, - 'train_before_export': False}, - { - 'model_builder': sequential_model_without_input_shape, - 'uses_learning_phase': True, - 'optimizer': training_module.AdadeltaOptimizer(), - 'train_before_export': False}) - def testSaveAndLoadSavedModelExport( - self, model_builder, uses_learning_phase, optimizer, train_before_export): - saved_model_path = self._save_model_dir() - with self.session(graph=ops.Graph()): - np.random.seed(130) - input_arr = np.random.random((1, 3)) - target_arr = np.random.random((1, 3)) - - model = model_builder(uses_learning_phase) - if optimizer is not None: - model.compile( - loss='mse', - optimizer=optimizer, - metrics=['mae']) - if train_before_export: - model.train_on_batch(input_arr, target_arr) - - ref_loss, ref_mae = model.evaluate(input_arr, target_arr) - - ref_predict = model.predict(input_arr) - - # Export SavedModel - output_path = keras_saved_model.save_keras_model(model, saved_model_path) - - input_name = model.input_names[0] - output_name = model.output_names[0] - target_name = output_name + '_target' - - # Load predict graph, and test predictions - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, _ = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) - - predictions = sess.run(outputs[output_name], - {inputs[input_name]: input_arr}) - self.assertAllClose(ref_predict, predictions, atol=1e-05) - - if optimizer: - # Load eval graph, and test predictions, loss and metric values - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, _ = load_model(sess, output_path, - model_fn_lib.ModeKeys.EVAL) - - # First obtain the loss and predictions, and run the metric update op by - # feeding in the inputs and targets. - loss, predictions, _ = sess.run( - (outputs['loss'], outputs['predictions/' + output_name], - outputs['metrics/mean_absolute_error/update_op']), { - inputs[input_name]: input_arr, - inputs[target_name]: target_arr - }) - - # The metric value should be run after the update op, to ensure that it - # reflects the correct value. - metric_value = sess.run(outputs['metrics/mean_absolute_error/value']) - - self.assertEqual(int(train_before_export), - sess.run(training_module.get_global_step())) - self.assertAllClose(ref_loss, loss, atol=1e-05) - self.assertAllClose(ref_mae, metric_value, atol=1e-05) - self.assertAllClose(ref_predict, predictions, atol=1e-05) - - # Load train graph, and check for the train op, and prediction values - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, meta_graph_def = load_model( - sess, output_path, model_fn_lib.ModeKeys.TRAIN) - self.assertEqual(int(train_before_export), - sess.run(training_module.get_global_step())) - self.assertIn('loss', outputs) - self.assertIn('metrics/mean_absolute_error/update_op', outputs) - self.assertIn('metrics/mean_absolute_error/value', outputs) - self.assertIn('predictions/' + output_name, outputs) - - # Train for a step - train_op = loader_impl.get_train_op(meta_graph_def) - train_outputs, _ = sess.run( - [outputs, train_op], {inputs[input_name]: input_arr, - inputs[target_name]: target_arr}) - self.assertEqual(int(train_before_export) + 1, - sess.run(training_module.get_global_step())) - - if uses_learning_phase: - self.assertAllClose( - [[0, 0, 0]], train_outputs['predictions/' + output_name], - atol=1e-05) - else: - self.assertNotAllClose( - [[0, 0, 0]], train_outputs['predictions/' + output_name], - atol=1e-05) - - def testSaveAndLoadSavedModelWithCustomObject(self): - saved_model_path = self._save_model_dir() - with session.Session(graph=ops.Graph()) as sess: - def relu6(x): - return keras.backend.relu(x, max_value=6) - inputs = keras.layers.Input(shape=(1,)) - outputs = keras.layers.Activation(relu6)(inputs) - model = keras.models.Model(inputs, outputs) - output_path = keras_saved_model.save_keras_model( - model, saved_model_path, custom_objects={'relu6': relu6}) - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, _ = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) - input_name = model.input_names[0] - output_name = model.output_names[0] - predictions = sess.run( - outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]}) - self.assertAllEqual([[6], [0], [4]], predictions) - - def testAssertModelCloneSameObjectsIgnoreOptimizer(self): - input_arr = np.random.random((1, 3)) - target_arr = np.random.random((1, 3)) - - model_graph = ops.Graph() - clone_graph = ops.Graph() - - # Create two models with the same layers but different optimizers. - with session.Session(graph=model_graph): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - x = keras.layers.Dense(3)(x) - model = keras.models.Model(inputs, x) - - model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer()) - model.train_on_batch(input_arr, target_arr) - - with session.Session(graph=clone_graph): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - x = keras.layers.Dense(3)(x) - clone = keras.models.Model(inputs, x) - clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) - clone.train_on_batch(input_arr, target_arr) - - keras_saved_model._assert_same_non_optimizer_objects( - model, model_graph, clone, clone_graph) - - def testAssertModelCloneSameObjectsThrowError(self): - input_arr = np.random.random((1, 3)) - target_arr = np.random.random((1, 3)) - - model_graph = ops.Graph() - clone_graph = ops.Graph() - - # Create two models with the same layers but different optimizers. - with session.Session(graph=model_graph): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - x = keras.layers.Dense(3)(x) - model = keras.models.Model(inputs, x) - - model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer()) - model.train_on_batch(input_arr, target_arr) - - with session.Session(graph=clone_graph): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - x = keras.layers.Dense(4)(x) - x = keras.layers.Dense(3)(x) - clone = keras.models.Model(inputs, x) - clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) - clone.train_on_batch(input_arr, target_arr) - - def testSaveSequentialModelWithoutInputShapes(self): - model = sequential_model_without_input_shape(True) - # A Sequential model that hasn't been built should raise an error. - with self.assertRaisesRegexp(ValueError, 'Please build the model'): - keras_saved_model.save_keras_model(model, '') - - saved_model_path = self._save_model_dir() - output_path = keras_saved_model.save_keras_model( - model, saved_model_path, - input_signature=tensor_spec.TensorSpec(shape=(10, 11, 12, 13, 14), - dtype=dtypes.float32, - name='spec_input')) - - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, _ = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) - self.assertEqual(5, inputs[next(iter(inputs.keys()))].shape.ndims) - self.assertEqual(5, outputs[next(iter(outputs.keys()))].shape.ndims) - self.assertEqual(3, outputs[next(iter(outputs.keys()))].shape[-1]) - - @test_util.run_v2_only - @parameterized.parameters( - { - 'model_builder': sequential_model_without_input_shape, - 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], - dtype=dtypes.float32)]}, - { - 'model_builder': subclassed_model, - 'input_signature': [tensor_spec.TensorSpec(shape=[None, 3], - dtype=dtypes.float32)]}) - def testServingOnly(self, model_builder, input_signature): - saved_model_path = self._save_model_dir() - input_arr = np.random.random((5, 3)).astype(np.float32) - model = model_builder() - ref_predict = model.predict(input_arr) - - output_path = keras_saved_model.save_keras_model( - model, saved_model_path, serving_only=True, - input_signature=input_signature) - - # Load predict graph, and test predictions - with session.Session(graph=ops.Graph()) as sess: - inputs, outputs, _ = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) - predictions = sess.run(outputs[next(iter(outputs.keys()))], - {inputs[next(iter(inputs.keys()))]: input_arr}) - self.assertAllClose(ref_predict, predictions, atol=1e-05) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 18b56cd21942e28cb0dc3210df0bb04d55c1e16f..7d5ba90ded215a59dbded751efd497f142a95e61 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -33,7 +33,6 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":beam_search_ops", - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/util:util_py", @@ -59,7 +58,6 @@ tf_custom_op_py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", ], @@ -215,3 +213,18 @@ cuda_py_test( "//tensorflow/python:variables", ], ) + +cuda_py_test( + name = "attention_wrapper_v2_test", + size = "medium", + srcs = ["python/kernel_tests/attention_wrapper_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff04e1780c4c44df14d6e87c5afdbf533ca5c90 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -0,0 +1,94 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.ops.attention_wrapper.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionMechanismTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(AttentionMechanismTest, self).setUp() + self.batch = 10 + self.timestep = 5 + self.memory_size = 6 + self.units = 8 + + self.memory = ops.convert_to_tensor( + np.random.random((self.batch, self.timestep, self.memory_size)), + dtype=np.float32) + self.query = ops.convert_to_tensor( + np.random.random((self.batch, self.units)), dtype=np.float32) + self.state = ops.convert_to_tensor( + np.random.random((self.batch, self.timestep)), dtype=np.float32) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_attention_shape_inference(self, attention_cls): + attention = attention_cls(self.units) + attention_score = attention([self.query, self.state, self.memory]) + self.assertLen(attention_score, 2) + self.assertEqual(attention_score[0].shape, (self.batch, self.timestep)) + self.assertEqual(attention_score[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_get_config(self, attention_cls): + attention = attention_cls(self.units) + config = attention.get_config() + + attention_from_config = attention_cls.from_config(config) + config_from_clone = attention_from_config.get_config() + + self.assertDictEqual(config, config_from_clone) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_layer_output(self, attention_cls): + attention = attention_cls(self.units) + + score = attention([self.query, self.state, self.memory]) + self.evaluate(variables.variables_initializer(attention.variables)) + + score_val = self.evaluate(score) + self.assertLen(score_val, 2) + self.assertEqual(score_val[0].shape, (self.batch, self.timestep)) + self.assertEqual(score_val[1].shape, (self.batch, self.timestep)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index b7f9f3fb090356a1c8d2bfb5044712ff93e267ce..abcf71c61b6e6df9462bf06323b8b11d5cc0d9a8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -34,8 +34,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -517,7 +515,7 @@ class BasicDecoderTest(test.TestCase): vocabulary_size) # The sample function samples categorically from the logits. - sample_fn = lambda x: categorical.Categorical(logits=x).sample() + sample_fn = lambda x: helper_py.categorical_sample(logits=x) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = ( lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) @@ -599,7 +597,7 @@ class BasicDecoderTest(test.TestCase): # The sample function samples independent bernoullis from the logits. sample_fn = ( - lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample()) + lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = math_ops.to_float end_fn = lambda sample_ids: sample_ids[:, end_token] diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 5aa32b532ffcf5772f6ace26662f5e5471cf6923..41b2a53ca5b178be9b04446c81d832575e5ed75b 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -14,80 +14,254 @@ # ============================================================================== """Tests for contrib.seq2seq.python.seq2seq.loss_ops.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np from tensorflow.contrib.seq2seq.python.ops import loss from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class LossTest(test.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 3 + self.number_of_classes = 5 + logits = [ + constant_op.constant(i + 0.5, shape=[self.batch_size, + self.number_of_classes]) + for i in range(self.sequence_length) + ] + self.logits = array_ops.stack(logits, axis=1) + targets = [ + constant_op.constant(i, dtypes.int32, shape=[self.batch_size]) + for i in range(self.sequence_length) + ] + self.targets = array_ops.stack(targets, axis=1) + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + self.weights = array_ops.stack(weights, axis=1) + # expected_loss = sparse_softmax_cross_entropy_with_logits(targets, logits) + # where targets = [0, 1, 2], and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] + self.expected_loss = 1.60944 + def testSequenceLoss(self): - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - 'root', initializer=init_ops.constant_initializer(0.5)): - batch_size = 2 - sequence_length = 3 - number_of_classes = 5 - logits = [ - constant_op.constant( - i + 0.5, shape=[batch_size, number_of_classes]) - for i in range(sequence_length) - ] - logits = array_ops.stack(logits, axis=1) - targets = [ - constant_op.constant( - i, dtypes.int32, shape=[batch_size]) - for i in range(sequence_length) - ] - targets = array_ops.stack(targets, axis=1) - weights = [ - constant_op.constant( - 1.0, shape=[batch_size]) for i in range(sequence_length) - ] - weights = array_ops.stack(weights, axis=1) - - average_loss_per_example = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=True) - res = sess.run(average_loss_per_example) - self.assertAllClose(1.60944, res) - - average_loss_per_sequence = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=True) - res = sess.run(average_loss_per_sequence) - compare_per_sequence = np.ones((sequence_length)) * 1.60944 - self.assertAllClose(compare_per_sequence, res) - - average_loss_per_batch = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=False) - res = sess.run(average_loss_per_batch) - compare_per_batch = np.ones((batch_size)) * 1.60944 - self.assertAllClose(compare_per_batch, res) - - total_loss = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=False) - res = sess.run(total_loss) - compare_total = np.ones((batch_size, sequence_length)) * 1.60944 - self.assertAllClose(compare_total, res) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSequenceLossClass(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSumReduction(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testWeightedSumReduction(self): + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + # Make the last element in the sequence to have zero weights. + weights[-1] = constant_op.constant(0.0, shape=[self.batch_size]) + self.weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. + compare_per_sequence[-1] = 0. + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + # The last element in every sequence are zeros, which will be filtered. + compare_total[:, -1] = 0 + self.assertAllClose(compare_total, res) + + def testZeroWeights(self): + weights = [ + constant_op.constant(0.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(0.0, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.zeros((self.sequence_length)) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.zeros((self.batch_size)) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.zeros((self.batch_size, self.sequence_length)) + self.assertAllClose(compare_total, res) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 60ec3efffe771a3a6d6f36ed4b51a34ef9509612..ae3e7f1b5d8c9f06b5defbaee9cad3810e58abd4 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import layers from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -72,77 +73,6 @@ class AttentionMechanism(object): raise NotImplementedError -def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): - """Convert to tensor and possibly mask `memory`. - - Args: - memory: `Tensor`, shaped `[batch_size, max_time, ...]`. - memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. - check_inner_dims_defined: Python boolean. If `True`, the `memory` - argument's shape is checked to ensure all but the two outermost - dimensions are fully defined. - - Returns: - A (possibly masked), checked, new `memory`. - - Raises: - ValueError: If `check_inner_dims_defined` is `True` and not - `memory.shape[2:].is_fully_defined()`. - """ - memory = nest.map_structure( - lambda m: ops.convert_to_tensor(m, name="memory"), memory) - if memory_sequence_length is not None: - memory_sequence_length = ops.convert_to_tensor( - memory_sequence_length, name="memory_sequence_length") - if check_inner_dims_defined: - def _check_dims(m): - if not m.get_shape()[2:].is_fully_defined(): - raise ValueError("Expected memory %s to have fully defined inner dims, " - "but saw shape: %s" % (m.name, m.get_shape())) - nest.map_structure(_check_dims, memory) - if memory_sequence_length is None: - seq_len_mask = None - else: - seq_len_mask = array_ops.sequence_mask( - memory_sequence_length, - maxlen=array_ops.shape(nest.flatten(memory)[0])[1], - dtype=nest.flatten(memory)[0].dtype) - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_sequence_length.shape[0]) - or array_ops.shape(memory_sequence_length)[0]) - def _maybe_mask(m, seq_len_mask): - rank = m.get_shape().ndims - rank = rank if rank is not None else array_ops.rank(m) - extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) - m_batch_size = tensor_shape.dimension_value( - m.shape[0]) or array_ops.shape(m)[0] - if memory_sequence_length is not None: - message = ("memory_sequence_length and memory tensor batch sizes do not " - "match.") - with ops.control_dependencies([ - check_ops.assert_equal( - seq_len_batch_size, m_batch_size, message=message)]): - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask - else: - return m - return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) - - -def _maybe_mask_score(score, memory_sequence_length, score_mask_value): - if memory_sequence_length is None: - return score - message = ("All values in memory_sequence_length must greater than zero.") - with ops.control_dependencies( - [check_ops.assert_positive(memory_sequence_length, message=message)]): - score_mask = array_ops.sequence_mask( - memory_sequence_length, maxlen=array_ops.shape(score)[1]) - score_mask_values = score_mask_value * array_ops.ones_like(score) - return array_ops.where(score_mask, score, score_mask_values) - - class _BaseAttentionMechanism(AttentionMechanism): """A base AttentionMechanism class providing common functionality. @@ -205,12 +135,14 @@ class _BaseAttentionMechanism(AttentionMechanism): self._memory_layer.dtype).as_numpy_dtype(-np.inf) self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda probability_fn( - _maybe_mask_score(score, memory_sequence_length, score_mask_value), + _maybe_mask_score(score, + memory_sequence_length=memory_sequence_length, + score_mask_value=score_mask_value), prev)) with ops.name_scope( name, "BaseAttentionMechanismInit", nest.flatten(memory)): self._values = _prepare_memory( - memory, memory_sequence_length, + memory, memory_sequence_length=memory_sequence_length, check_inner_dims_defined=check_inner_dims_defined) self._keys = ( self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable @@ -286,6 +218,207 @@ class _BaseAttentionMechanism(AttentionMechanism): return self.initial_alignments(batch_size, dtype) +class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): + """A base AttentionMechanism class providing common functionality. + + Common functionality includes: + 1. Storing the query and memory layers. + 2. Preprocessing and storing the memory. + + Note that this layer only support Keras functional API since it takes multiple + input tensors, which is not available in sequential model. + """ + + def __init__(self, + probability_fn, + query_layer=None, + memory_layer=None, + **kwargs): + """Construct base AttentionMechanism class. + + Args: + probability_fn: A `callable`. Converts the score and previous alignments + to probabilities. Its signature should be: + `probabilities = probability_fn(score, state)`. + query_layer: (optional): Instance of `tf.keras.Layer`. The layer's depth + must match the depth of `memory_layer`. If `query_layer` is not + provided, the shape of `query` must match that of `memory_layer`. + memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's + depth must match the depth of `query_layer`. + If `memory_layer` is not provided, the shape of `memory` must match + that of `query_layer`. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + if (query_layer is not None + and not isinstance(query_layer, layers.Layer)): + raise TypeError( + "query_layer is not a Layer: %s" % type(query_layer).__name__) + if (memory_layer is not None + and not isinstance(memory_layer, layers.Layer)): + raise TypeError( + "memory_layer is not a Layer: %s" % type(memory_layer).__name__) + self.query_layer = query_layer + self.memory_layer = memory_layer + if self.memory_layer is not None and "dtype" not in kwargs: + kwargs["dtype"] = self.memory_layer.dtype + super(_BaseAttentionMechanismV2, self).__init__(**kwargs) + if not callable(probability_fn): + raise TypeError("probability_fn must be callable, saw type: %s" % + type(probability_fn).__name__) + self.probability_fn = probability_fn + + self.keys = None + self.values = None + self.batch_size = None + self._memory_initialized = False + self._check_inner_dims_defined = True + + def build(self, input_shape): + # The layer suppose to take 3 inputs, [query, state, memory]. + query_input_shape, _, memory_input_shape = input_shape + if self.query_layer is not None: + self.query_layer.build(query_input_shape) + if self.memory_layer is not None: + self.memory_layer.build(memory_input_shape) + # dtype of the layer is known at this moment, create the score_mask_value if + # needed. + self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) + self.built = True + + def _setup_memory(self, memory, memory_mask=None): + """Pre-process the memory before actually query the memory. + + This should only be called once at the first invocation of call(). + + Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_mask: The boolean tensor with shape `[batch_size, max_time]`. For + any value equal to False, the corresponding value in memory should be + ignored. + """ + if self._memory_initialized: + raise ValueError("The memory for the attention has already been setup.") + with ops.name_scope( + self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): + self.values = _prepare_memory( + memory, memory_mask=memory_mask, + check_inner_dims_defined=self._check_inner_dims_defined) + if self.memory_layer is not None: + self.keys = self.memory_layer(self.values) + else: + self.keys = self.values + self.batch_size = ( + tensor_shape.dimension_value(self.keys.shape[0]) or + array_ops.shape(self.keys)[0]) + self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) + or array_ops.shape(self.keys)[1]) + if memory_mask is not None: + self.probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda + self.probability_fn(_maybe_mask_score( + score, self.score_mask_value, memory_mask=memory_mask), prev)) + self._memory_initialized = True + + def call(self, inputs, mask=None, **kwargs): + """Base method to calculate the attention score. + + Args: + inputs: a list of tensor that contains `query`, `state`, and `memory`. + `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. + `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, feature]`. + mask: optional bool tensor with shape `[batch, max_time]` for the mask of + memory. If it is not None, the corresponding item of the memory should + be filtered out during calculation. + **kwargs: Dict, other keyword arguments for the call method. + """ + query, state, memory, memory_mask = self._process_inputs(inputs, mask) + if not self._memory_initialized: + self._setup_memory(memory, memory_mask=memory_mask) + return self.calculate_attention(query, state) + + def calculate_attention(self, query, state): + raise NotImplementedError( + "calculate_attention need to be implemented by subclasses.") + + def get_config(self): + config = {} + # Since the probability_fn is likely to be a wrapped function, the child + # class should preserve the original function and how its wrapped. + + if self.query_layer is not None: + config["query_layer"] = { + "class_name": self.query_layer.__class__.__name__, + "config": self.query_layer.get_config(), + } + if self.memory_layer is not None: + config["memory_layer"] = { + "class_name": self.memory_layer.__class__.__name__, + "config": self.memory_layer.get_config(), + } + base_config = super(_BaseAttentionMechanismV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def _process_inputs(self, inputs, mask): + if len(inputs) != 3: + raise ValueError( + "Expect to have 3 inputs for attention, got %d" % len(inputs)) + query, state, memory = inputs + return query, state, memory, mask + + def _process_probability_fn(self, func_name): + """Helper method to retrieve the probably function by string input.""" + valid_probability_fns = { + "softmax": nn_ops.softmax, + "hardmax": hardmax, + } + if func_name not in valid_probability_fns.keys(): + raise ValueError("Invalid probability function: %s, options are %s" % + (func_name, valid_probability_fns.keys())) + return valid_probability_fns[func_name] + + @classmethod + def deserialize_inner_layer_from_config(cls, config, custom_objects): + """Helper method that reconstruct the query and memory from the config. + + In the get_config() method, the query and memory layer configs are + serialized into dict for persistence, this method perform the reverse action + to reconstruct the layer from the config. + + Args: + config: dict, the configs that will be used to reconstruct the object. + custom_objects: dict mapping class names (or function names) of custom + (non-Keras) objects to class/functions. + Returns: + config: dict, the config with layer instance created, which is ready to be + used as init parameters. + """ + # Reconstruct the query and memory layer for parent class. + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + # Instead of updating the input, create a copy and use that. + config = config.copy() + query_layer_config = config.pop("query_layer", None) + if query_layer_config: + query_layer = deserialize_layer(query_layer_config, + custom_objects=custom_objects) + config["query_layer"] = query_layer + memory_layer_config = config.pop("memory_layer", None) + if memory_layer_config: + memory_layer = deserialize_layer(memory_layer_config, + custom_objects=custom_objects) + config["memory_layer"] = memory_layer + return config + + @property + def alignments_size(self): + return self._alignments_size + + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. @@ -304,7 +437,7 @@ def _luong_score(query, keys, scale): Args: query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - scale: Whether to apply a scale to the score function. + scale: the optional tensor to scale the attention score. Returns: A `[batch_size, max_time]` tensor of unnormalized score values. @@ -320,7 +453,6 @@ def _luong_score(query, keys, scale): "Query (%s) has units: %s. Keys (%s) have units: %s. " "Perhaps you need to set num_units to the keys' dimension (%s)?" % (query, depth, keys, key_units, key_units)) - dtype = query.dtype # Reshape from [batch_size, depth] to [batch_size, 1, depth] # for matmul. @@ -338,12 +470,8 @@ def _luong_score(query, keys, scale): score = math_ops.matmul(query, keys, transpose_b=True) score = array_ops.squeeze(score, [1]) - if scale: - # Scalar used in weight scaling - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.ones_initializer, shape=()) - score = g * score + if scale is not None: + score = scale * score return score @@ -354,8 +482,8 @@ class LuongAttention(_BaseAttentionMechanism): as described in: Minh-Thang Luong, Hieu Pham, Christopher D. Manning. - "Effective Approaches to Attention-based Neural Machine Translation." - EMNLP 2015. https://arxiv.org/abs/1508.04025 + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) The second is the scaled form inspired partly by the normalized form of Bahdanau attention. @@ -429,13 +557,125 @@ class LuongAttention(_BaseAttentionMechanism): `max_time`). """ with variable_scope.variable_scope(None, "luong_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state -def _bahdanau_score(processed_query, keys, normalize): +class LuongAttentionV2(_BaseAttentionMechanismV2): + """Implements Luong-style (multiplicative) attention scoring. + + This attention has two forms. The first is standard Luong attention, + as described in: + + Minh-Thang Luong, Hieu Pham, Christopher D. Manning. + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) + + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. + + To enable the second form, construct the object with parameter + `scale=True`. + """ + + def __init__(self, + units, + scale=False, + probability_fn="softmax", + dtype=None, + name="LuongAttention", + **kwargs): + """Construct the AttentionMechanism mechanism. + + Args: + units: The depth of the attention mechanism. + scale: Python boolean. Whether to scale the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. Any other value will result intovalidation error. Default + to use `softmax`. + dtype: The data type for the memory layer of the attention mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # For LuongAttention, we only transform the memory layer; thus + # num_units **must** match expected the query depth. + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(LuongAttentionV2, self).__init__( + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.scale = scale + + def build(self, input_shape): + super(LuongAttentionV2, self).build(input_shape) + if self.scale: + self.scale_weight = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + else: + self.scale_weight = None + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as the alignments. + """ + score = _luong_score(query, self.keys, self.scale_weight) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "probability_fn": self.probability_fn_name, + } + base_config = super(LuongAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + +def _bahdanau_score(processed_query, keys, attention_v, + attention_g=None, attention_b=None): """Implements Bahdanau-style (additive) scoring function. This attention has two forms. The first is Bhandanau attention, @@ -453,41 +693,28 @@ def _bahdanau_score(processed_query, keys, normalize): Training of Deep Neural Networks." https://arxiv.org/abs/1602.07868 - To enable the second form, set `normalize=True`. + To enable the second form, set please pass in attention_g and attention_b. Args: processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - normalize: Whether to normalize the score function. + attention_v: Tensor, shape `[num_units]`. + attention_g: Optional scalar tensor for normalization. + attention_b: Optional tensor with shape `[num_units]` for normalization. Returns: A `[batch_size, max_time]` tensor of unnormalized score values. """ - dtype = processed_query.dtype - # Get the number of hidden units from the trailing dimension of keys - num_units = tensor_shape.dimension_value( - keys.shape[2]) or array_ops.shape(keys)[2] # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. processed_query = array_ops.expand_dims(processed_query, 1) - v = variable_scope.get_variable( - "attention_v", [num_units], dtype=dtype) - if normalize: - # Scalar used in weight normalization - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))), - shape=()) - # Bias added prior to the nonlinearity - b = variable_scope.get_variable( - "attention_b", [num_units], dtype=dtype, - initializer=init_ops.zeros_initializer()) - # normed_v = g * v / ||v|| - normed_v = g * v * math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(v))) + if attention_g is not None and attention_b is not None: + normed_v = attention_g * attention_v * math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(attention_v))) return math_ops.reduce_sum( - normed_v * math_ops.tanh(keys + processed_query + b), [2]) + normed_v * math_ops.tanh(keys + processed_query + attention_b), [2]) else: - return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2]) + return math_ops.reduce_sum( + attention_v * math_ops.tanh(keys + processed_query), [2]) class BahdanauAttention(_BaseAttentionMechanism): @@ -578,12 +805,152 @@ class BahdanauAttention(_BaseAttentionMechanism): """ with variable_scope.variable_scope(None, "bahdanau_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state +class BahdanauAttentionV2(_BaseAttentionMechanismV2): + """Implements Bahdanau-style (additive) attention. + + This attention has two forms. The first is Bahdanau attention, + as described in: + + Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. + "Neural Machine Translation by Jointly Learning to Align and Translate." + ICLR 2015. https://arxiv.org/abs/1409.0473 + + The second is the normalized form. This form is inspired by the + weight normalization article: + + Tim Salimans, Diederik P. Kingma. + "Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks." + https://arxiv.org/abs/1602.07868 + + To enable the second form, construct the object with parameter + `normalize=True`. + """ + + def __init__(self, + units, + normalize=False, + probability_fn="softmax", + dtype=None, + name="BahdanauAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + normalize: Python boolean. Whether to normalize the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. Any other value will result into validation error. Default + to use `softmax`. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(BahdanauAttentionV2, self).__init__( + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.normalize = normalize + + def build(self, input_shape): + super(BahdanauAttentionV2, self).build(input_shape) + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype) + if self.normalize: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.constant_initializer( + math.sqrt((1. / self.units))), shape=()) + self.attention_b = self.add_weight( + "attention_b", shape=[self.units], + initializer=init_ops.zeros_initializer()) + else: + self.attention_g = None + self.attention_b = None + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: same as alignments. + """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "probability_fn": self.probability_fn_name, + } + base_config = super(BahdanauAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + def safe_cumprod(x, *args, **kwargs): """Computes cumprod of x in logspace using cumsum to avoid underflow. @@ -766,6 +1133,34 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): dtype=dtype) +class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2): + """Base attention mechanism for monotonic attention. + + Simply overrides the initial_alignments function to provide a dirac + distribution, which is needed in order for the monotonic attention + distributions to have the correct behavior. + """ + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the monotonic attentions. + + Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0] + for all entries in the batch. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + max_time = self._alignments_size + return array_ops.one_hot( + array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time, + dtype=dtype) + + class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Bahadanau-style energy function. @@ -860,7 +1255,22 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): with variable_scope.variable_scope( None, "bahdanau_monotonic_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=processed_query.dtype, initializer=self._score_bias_init) @@ -870,6 +1280,146 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Bahadanau-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. Since the attention scores are passed + through a sigmoid, a learnable scalar bias parameter is applied after the + score function and before the sigmoid. Otherwise, it is equivalent to + BahdanauAttention. This approach is proposed in + + Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017. https://arxiv.org/abs/1704.00784 + """ + + def __init__(self, + units, + normalize=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + dtype=None, + name="BahdanauMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + normalize: Python boolean. Whether to normalize the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(BahdanauMonotonicAttentionV2, self).__init__( + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.normalize = normalize + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + + def build(self, input_shape): + super(BahdanauMonotonicAttentionV2, self).build(input_shape) + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype) + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), dtype=self.dtype, + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + if not self.normalize: + self.attention_g = None + self.attention_b = None + else: + self.attention_g = self.add_weight( + "attention_g", dtype=self.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self.units))), + shape=()) + self.attention_b = self.add_weight( + "attention_b", [self.units], dtype=self.dtype, + initializer=init_ops.zeros_initializer()) + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super(BahdanauMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Luong-style energy function. @@ -960,7 +1510,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """ with variable_scope.variable_scope(None, "luong_monotonic_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=query.dtype, initializer=self._score_bias_init) @@ -970,6 +1525,129 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Luong-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. Otherwise, it is equivalent to + LuongAttention. This approach is proposed in + + [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017.](https://arxiv.org/abs/1704.00784) + """ + + def __init__(self, + units, + scale=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + dtype=None, + name="LuongMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + scale: Python boolean. Whether to scale the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(LuongMonotonicAttentionV2, self).__init__( + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.scale = scale + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + + def build(self, input_shape): + super(LuongMonotonicAttentionV2, self).build(input_shape) + if self.scale: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + else: + self.attention_g = None + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as alignments + """ + score = _luong_score(query, self.keys, self.attention_g) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super(LuongMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class AttentionWrapperState( collections.namedtuple("AttentionWrapperState", ("cell_state", "attention", "time", "alignments", @@ -1026,6 +1704,97 @@ class AttentionWrapperState( super(AttentionWrapperState, self)._replace(**kwargs)) +def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, + check_inner_dims_defined=True): + """Convert to tensor and possibly mask `memory`. + + Args: + memory: `Tensor`, shaped `[batch_size, max_time, ...]`. + memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. + memory_mask: `boolean` tensor with shape [batch_size, max_time]. The memory + should be skipped when the corresponding mask is False. + check_inner_dims_defined: Python boolean. If `True`, the `memory` + argument's shape is checked to ensure all but the two outermost + dimensions are fully defined. + + Returns: + A (possibly masked), checked, new `memory`. + + Raises: + ValueError: If `check_inner_dims_defined` is `True` and not + `memory.shape[2:].is_fully_defined()`. + """ + memory = nest.map_structure( + lambda m: ops.convert_to_tensor(m, name="memory"), memory) + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + memory_sequence_length = ops.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length") + if check_inner_dims_defined: + def _check_dims(m): + if not m.get_shape()[2:].is_fully_defined(): + raise ValueError("Expected memory %s to have fully defined inner dims, " + "but saw shape: %s" % (m.name, m.get_shape())) + nest.map_structure(_check_dims, memory) + if memory_sequence_length is None and memory_mask is None: + seq_len_mask = None + seq_len_batch_size = None + elif memory_sequence_length is not None: + seq_len_mask = array_ops.sequence_mask( + memory_sequence_length, + maxlen=array_ops.shape(nest.flatten(memory)[0])[1], + dtype=nest.flatten(memory)[0].dtype) + seq_len_batch_size = ( + tensor_shape.dimension_value(memory_sequence_length.shape[0]) + or array_ops.shape(memory_sequence_length)[0]) + else: + # For memory_mask is not None + seq_len_mask = memory_mask + seq_len_batch_size = ( + tensor_shape.dimension_value(memory_mask.shape[0]) + or array_ops.shape(memory_mask)[0]) + def _maybe_mask(m, seq_len_mask): + """Mask the memory based on the memory mask.""" + rank = m.get_shape().ndims + rank = rank if rank is not None else array_ops.rank(m) + extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) + m_batch_size = tensor_shape.dimension_value( + m.shape[0]) or array_ops.shape(m)[0] + if seq_len_batch_size is not None: + message = ("memory_sequence_length and memory tensor batch sizes do not " + "match.") + with ops.control_dependencies([ + check_ops.assert_equal( + seq_len_batch_size, m_batch_size, message=message)]): + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask + else: + return m + return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) + + +def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, + score_mask_value=None): + """Mask the attention score based on the masks.""" + if memory_sequence_length is None and memory_mask is None: + return score + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + message = "All values in memory_sequence_length must greater than zero." + with ops.control_dependencies( + [check_ops.assert_positive(memory_sequence_length, message=message)]): + memory_mask = array_ops.sequence_mask( + memory_sequence_length, maxlen=array_ops.shape(score)[1]) + score_mask_values = score_mask_value * array_ops.ones_like(score) + return array_ops.where(memory_mask, score, score_mask_values) + + def hardmax(logits, name=None): """Returns batched one-hot vectors. diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 3245cc5e72154289ea3ba000b9a30586a7ad03a9..033c2eb0801d5a51ee937f5e960faa91a6f1ae54 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -32,9 +32,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.util import nest __all__ = [ @@ -51,6 +50,68 @@ __all__ = [ _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. + + +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). + pad = math_ops.cast(math_ops.equal(array_ops.rank(sample_shape), 0), + dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad(array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, logits=None, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform( + new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + return _call_sampler(_sample_n, sample_shape) + + def _unstack_ta(inp): return tensor_array_ops.TensorArray( dtype=inp.dtype, size=array_ops.shape(inp)[0], @@ -307,14 +368,14 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper): with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state]): # Return -1s where we did not sample, and sample_ids elsewhere - select_sampler = bernoulli.Bernoulli( - probs=self._sampling_probability, dtype=dtypes.bool) - select_sample = select_sampler.sample( - sample_shape=self.batch_size, seed=self._scheduling_seed) - sample_id_sampler = categorical.Categorical(logits=outputs) + select_sample = bernoulli_sample( + probs=self._sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self._scheduling_seed) return array_ops.where( select_sample, - sample_id_sampler.sample(seed=self._seed), + categorical_sample(logits=outputs, seed=self._seed), gen_array_ops.fill([self.batch_size], -1)) def next_inputs(self, time, outputs, state, sample_ids, name=None): @@ -425,8 +486,10 @@ class ScheduledOutputTrainingHelper(TrainingHelper): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return sampler.sample(sample_shape=self.batch_size, seed=self._seed) + return bernoulli_sample( + probs=self._sampling_probability, + sample_shape=self.batch_size, + seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -610,8 +673,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): else: logits = outputs / self._softmax_temperature - sample_id_sampler = categorical.Categorical(logits=logits) - sample_ids = sample_id_sampler.sample(seed=self._seed) + sample_ids = categorical_sample(logits=logits, seed=self._seed) return sample_ids diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py index 39a6d2f58b140706a94d83273d3327edd1891368..0fbfd6187030f14ac105a18b3e09b7a42d4de32a 100644 --- a/tensorflow/contrib/seq2seq/python/ops/loss.py +++ b/tensorflow/contrib/seq2seq/python/ops/loss.py @@ -20,11 +20,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.keras.losses import Loss from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -__all__ = ["sequence_loss"] +__all__ = ["sequence_loss", "SequenceLoss"] def sequence_loss(logits, @@ -32,16 +33,26 @@ def sequence_loss(logits, weights, average_across_timesteps=True, average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False, softmax_loss_function=None, name=None): """Weighted cross-entropy loss for a sequence of logits. - Depending on the values of `average_across_timesteps` and - `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these - arguments reduce the cross-entropy at each target, which has shape - `[batch_size, sequence_length]`, over their respective dimensions. For - example, if `average_across_timesteps` is `True` and `average_across_batch` - is `False`, then the return Tensor will have shape `[batch_size]`. + Depending on the values of `average_across_timesteps` / `sum_over_timesteps` + and `average_across_batch` / `sum_over_batch`, the return Tensor will have + rank 0, 1, or 2 as these arguments reduce the cross-entropy at each target, + which has shape `[batch_size, sequence_length]`, over their respective + dimensions. For example, if `average_across_timesteps` is `True` and + `average_across_batch` is `False`, then the return Tensor will have shape + `[batch_size]`. + + Note that `average_across_timesteps` and `sum_over_timesteps` cannot be True + at same time. Same for `average_across_batch` and `sum_over_batch`. + + The recommended loss reduction in tf 2.0 has been changed to sum_over, instead + of weighted average. User are recommend to use `sum_over_timesteps` and + `sum_over_batch` for reduction. Args: logits: A Tensor of shape @@ -58,6 +69,12 @@ def sequence_loss(logits, dimension and divide the cost by the total label weight across timesteps. average_across_batch: If set, sum the cost across the batch dimension and divide the returned cost by the batch size. + sum_over_timesteps: If set, sum the cost across the sequence dimension and + divide the size of the sequence. Note that any element with 0 weights will + be excluded from size calculation. + sum_over_batch: if set, sum the cost across the batch dimension and divide + the total cost by the batch size. Not that any element with 0 weights will + be excluded from size calculation. softmax_loss_function: Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None). **Note that to avoid confusion, it is required for the function to accept @@ -78,11 +95,15 @@ def sequence_loss(logits, raise ValueError("Logits must be a " "[batch_size x sequence_length x logits] tensor") if len(targets.get_shape()) != 2: - raise ValueError("Targets must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Targets must be a [batch_size x sequence_length] tensor") if len(weights.get_shape()) != 2: - raise ValueError("Weights must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Weights must be a [batch_size x sequence_length] tensor") + if average_across_timesteps and sum_over_timesteps: + raise ValueError("average_across_timesteps and sum_over_timesteps cannot " + "be set to True at same time.") + if average_across_batch and sum_over_batch: + raise ValueError("average_across_batch and sum_over_batch cannot be set " + "to True at same time.") with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): num_classes = array_ops.shape(logits)[2] logits_flat = array_ops.reshape(logits, [-1, num_classes]) @@ -96,20 +117,56 @@ def sequence_loss(logits, if average_across_timesteps and average_across_batch: crossent = math_ops.reduce_sum(crossent) total_size = math_ops.reduce_sum(weights) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps and sum_over_batch: + crossent = math_ops.reduce_sum(crossent) + total_count = math_ops.cast(math_ops.count_nonzero(weights), + crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) else: - batch_size = array_ops.shape(logits)[0] - sequence_length = array_ops.shape(logits)[1] - crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) - if average_across_timesteps and not average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[1]) - total_size = math_ops.reduce_sum(weights, axis=[1]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size - if not average_across_timesteps and average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[0]) - total_size = math_ops.reduce_sum(weights, axis=[0]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = array_ops.reshape(crossent, array_ops.shape(logits)[0:2]) + if average_across_timesteps or average_across_batch: + reduce_axis = [0] if average_across_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_size = math_ops.reduce_sum(weights, axis=reduce_axis) + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps or sum_over_batch: + reduce_axis = [0] if sum_over_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_count = math_ops.cast( + math_ops.count_nonzero(weights, axis=reduce_axis), + dtype=crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) return crossent + + +class SequenceLoss(Loss): + """Weighted cross-entropy loss for a sequence of logits.""" + + def __init__(self, + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + softmax_loss_function=None, + name=None): + super(SequenceLoss, self).__init__(name=name) + self.average_across_timesteps = average_across_timesteps + self.average_across_batch = average_across_batch + self.sum_over_timesteps = sum_over_timesteps + self.sum_over_batch = sum_over_batch + self.softmax_loss_function = softmax_loss_function + + def __call__(self, y_true, y_pred, sample_weight=None): + """Override the parent __call__ to have a customized reduce behavior.""" + return sequence_loss(y_pred, y_true, sample_weight, + average_across_timesteps=self.average_across_timesteps, + average_across_batch=self.average_across_batch, + sum_over_timesteps=self.sum_over_timesteps, + sum_over_batch=self.sum_over_batch, + softmax_loss_function=self.softmax_loss_function, + name=self.name) + + def call(self, y_true, y_pred): + # Skip this method since the __call__ contains real implementation. + pass diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index 08983337fccc138d40eb959cecc5bf9e47cf6cac..f3efd292cf5acba4319c8a5545a7f70fae4b5ce1 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -304,10 +304,10 @@ class Exporter(object): def parser(path): if os.name == "nt": match = re.match( - "^" + export_dir_base.replace("\\", "/") + "/(\\d{8})$", + r"^" + export_dir_base.replace("\\", "/") + r"/(\d{8})$", path.path.replace("\\", "/")) else: - match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) + match = re.match(r"^" + export_dir_base + r"/(\d{8})$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py index 8faf3ef3d4cd7ee0096265283070e25d06782254..02725bb1cbb4ef9ace29dcc58f6d23fb241d96b2 100644 --- a/tensorflow/contrib/session_bundle/gc_test.py +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -104,7 +104,7 @@ class GcTest(test_util.TensorFlowTestCase): # create a simple parser that pulls the export_version from the directory. def parser(path): - match = re.match("^" + base_dir + "/(\\d+)$", path.path) + match = re.match(r"^" + base_dir + r"/(\d+)$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index d7ba754f701d4b433e35ad8396eae7ee6132b97f..ed4eca1a60a6f0ccf629d8aa7906c02092e25ba0 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -49,6 +49,9 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) cuda_py_tests( @@ -64,4 +67,7 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d..0d87cea9fbaa8fe28b55ec996414a568d39efee3 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -50,9 +50,10 @@ def _accuracy(predictions, targets, weights=None): def _r2(probabilities, targets, weights=None): targets = math_ops.to_float(targets) y_mean = math_ops.reduce_mean(targets, 0) - squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) + squares_total = math_ops.reduce_sum( + math_ops.squared_difference(targets, y_mean), 0) squares_residuals = math_ops.reduce_sum( - math_ops.square(targets - probabilities), 0) + math_ops.squared_difference(targets, probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) return metrics.mean(score, weights=weights) diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index e04eb60f9b27cfd8b6b4e1502594d4d310ae55cc..774da472f1543f938d1b607ebdef008f7b540211 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -18,10 +18,10 @@ #include #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index d3edb43733761a906c6e5bf8b65f76e3e1ae56fc..3100a5a0e5da1103b61bd089cd433721686b9e72 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -32,7 +32,7 @@ class DecisionTreeResource : public ResourceBase { // Constructor. explicit DecisionTreeResource(const TensorForestParams& params); - string DebugString() override { + string DebugString() const override { return strings::StrCat("DecisionTree[size=", decision_tree_->decision_tree().nodes_size(), "]"); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index eea0be27caf0a022ba7acaacd359c75a2df4eedb..44f2b3f473b9eced06bd800b9cf0a5a0825ec3eb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,7 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { return "FertileStats"; } + string DebugString() const override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 784acce444a8d0c066f1b7ae6c1b5d7d65405549..67461450f8ae53739f619622de8751b654dbc082 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,18 +11,12 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", "tf_copts", "tf_cuda_library", - "tf_custom_op_library", "tf_custom_op_library_additional_deps", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", @@ -33,127 +27,17 @@ exports_files(glob([ "test/testdata/*", ])) -tf_cuda_cc_test( - name = "tensorrt_test_cc", - size = "small", - srcs = ["tensorrt_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - "//tensorflow/core:gpu_init", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_custom_op_library( - name = "python/ops/_trt_engine_op.so", - srcs = [ - "ops/trt_engine_op.cc", - ], - deps = [ - ":trt_shape_function", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - tf_cuda_library( name = "trt_shape_function", srcs = ["shape_fn/trt_shfn.cc"], hdrs = ["shape_fn/trt_shfn.h"], visibility = ["//visibility:public"], deps = [ - ":trt_logging", - ":trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -cc_library( - name = "trt_engine_op_kernel", - srcs = [ - "kernels/trt_engine_op.cc", - ], - hdrs = [ - "kernels/trt_engine_op.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":test_utils", - ":trt_allocator", - ":trt_conversion", - ":trt_logging", - ":trt_plugins", - ":trt_resources", - ":utils", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/compiler/tf2tensorrt:trt_logging", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd): fix this by merging header file in cc file. - alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs -) - -tf_gen_op_libs( - op_lib_names = [ - "trt_engine_op", - ], -) - -tf_cuda_library( - name = "trt_logging", - srcs = ["log/trt_logger.cc"], - hdrs = ["log/trt_logger.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_gen_op_wrapper_py( - name = "trt_engine_op", - deps = [ - ":trt_engine_op_op_lib", - ":trt_logging", - ":trt_shape_function", - ], -) - -tf_custom_op_py_library( - name = "trt_engine_op_loader", - srcs = ["python/ops/trt_engine_op.py"], - dso = [ - ":python/ops/_trt_engine_op.so", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), - kernels = [ - ":trt_engine_op_kernel", - ":trt_engine_op_op_lib", - ":trt_shape_function", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", - ], ) py_library( @@ -175,8 +59,8 @@ py_library( name = "trt_ops_py", srcs_version = "PY2AND3", deps = [ - ":trt_engine_op", - ":trt_engine_op_loader", + "//tensorflow/compiler/tf2tensorrt:trt_ops", + "//tensorflow/compiler/tf2tensorrt:trt_ops_loader", ], ) @@ -205,247 +89,13 @@ tf_py_wrap_cc( "//tensorflow/python:platform/base.i", ], deps = [ - ":test_utils", - ":trt_conversion", - ":trt_engine_op_kernel", + "//tensorflow/compiler/tf2tensorrt:test_utils", + "//tensorflow/compiler/tf2tensorrt:trt_conversion", + "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", "//third_party/python_runtime:headers", ], ) -tf_cuda_library( - name = "trt_resources", - srcs = [ - "resources/trt_int8_calibrator.cc", - "resources/trt_resource_manager.cc", - ], - hdrs = [ - "resources/trt_int8_calibrator.h", - "resources/trt_resource_manager.h", - "resources/trt_resources.h", - ], - deps = [ - ":trt_allocator", - ":trt_logging", - ":utils", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_library( - name = "trt_allocator", - srcs = ["resources/trt_allocator.cc"], - hdrs = ["resources/trt_allocator.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cc_test( - name = "trt_allocator_test", - size = "small", - srcs = ["resources/trt_allocator_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":trt_allocator", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# Library for the node-level conversion portion of TensorRT operation creation -tf_cuda_library( - name = "trt_conversion", - srcs = [ - "convert/convert_graph.cc", - "convert/convert_nodes.cc", - "convert/trt_optimization_pass.cc", - ], - hdrs = [ - "convert/convert_graph.h", - "convert/convert_nodes.h", - "convert/trt_optimization_pass.h", - ], - deps = [ - ":segment", - ":test_utils", - ":trt_allocator", - ":trt_plugins", - ":trt_logging", - ":trt_resources", - ":utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:framework_lite", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:devices", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_cuda_cc_test( - name = "convert_graph_test", - size = "medium", - srcs = ["convert/convert_graph_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_conversion", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:direct_session", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "convert_nodes_test", - size = "medium", - srcs = ["convert/convert_nodes_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_logging", - ":trt_conversion", - ":trt_plugins", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -# Library for the segmenting portion of TensorRT operation creation -cc_library( - name = "segment", - srcs = ["segment/segment.cc"], - hdrs = [ - "segment/segment.h", - "segment/union_find.h", - ], - deps = [ - "//tensorflow/core:graph", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "@protobuf_archive//:protobuf_headers", - ], -) - -tf_cc_test( - name = "segment_test", - size = "small", - srcs = ["segment/segment_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":segment", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:lib", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -# Library for the plugin factory -tf_cuda_library( - name = "trt_plugins", - srcs = [ - "plugin/trt_plugin.cc", - "plugin/trt_plugin_factory.cc", - "plugin/trt_plugin_utils.cc", - ], - hdrs = [ - "plugin/trt_plugin.h", - "plugin/trt_plugin_factory.h", - "plugin/trt_plugin_utils.h", - ], - deps = [ - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "trt_plugin_factory_test", - size = "small", - srcs = ["plugin/trt_plugin_factory_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_plugins", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - py_library( name = "tf_trt_integration_test_base", srcs = ["test/tf_trt_integration_test_base.py"], @@ -491,6 +141,11 @@ cuda_py_tests( "test/binary_tensor_weight_broadcast_test.py", "test/concatenation_test.py", "test/const_broadcast_test.py", + "test/conv2d_test.py", + "test/dynamic_input_shapes_test.py", + "test/identity_output_test.py", + "test/int32_test.py", + "test/lru_cache_test.py", "test/manual_test.py", "test/memory_alignment_test.py", "test/multi_connection_neighbor_engine_test.py", @@ -498,6 +153,8 @@ cuda_py_tests( "test/quantization_test.py", "test/rank_two_test.py", "test/reshape_transpose_test.py", + "test/topk_test.py", + "test/unary_test.py", "test/vgg_block_nchw_test.py", "test/vgg_block_test.py", ], @@ -513,25 +170,6 @@ cuda_py_tests( ], ) -cuda_py_tests( - name = "tf_trt_integration_test_no_oss", - srcs = [ - "test/unary_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_oss", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_pip", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_windows", - "nomac", - ], -) - cuda_py_test( name = "quantization_mnist_test", srcs = ["test/quantization_mnist_test.py"], @@ -556,22 +194,20 @@ cuda_py_test( ], ) -cc_library( - name = "utils", - srcs = ["convert/utils.cc"], - hdrs = ["convert/utils.h"], - copts = tf_copts(), - deps = [ - "//tensorflow/core:lib", - ], +# The following rules forward the libraries that were moved in order to not +# break other internal targets. + +alias( + name = "trt_conversion", + actual = "//tensorflow/compiler/tf2tensorrt:trt_conversion", ) -cc_library( - name = "test_utils", - srcs = ["test/utils.cc"], - hdrs = ["test/utils.h"], - deps = [ - "//tensorflow/core:lib", - "@com_googlesource_code_re2//:re2", - ], +alias( + name = "trt_op_kernels", + actual = "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", +) + +alias( + name = "trt_engine_op_op_lib", + actual = "//tensorflow/compiler/tf2tensorrt:trt_engine_op_op_lib", ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 69058c5826822c519a69d50860c06b8ab3ec6578..0a2cf105baf5efb62d0c535c1f2d081973ec0ea3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -45,10 +45,10 @@ tf_custom_op_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:framework_lite", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]), ) @@ -64,10 +64,10 @@ tf_kernel_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:stream_executor_headers_lib", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 8d4c893af56689185da72398919e2241d451594b..7c9075142a02546ddd580e861ac87cb86badd739 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 189e9c939b9ffd4450f7ba95fe1abdbbc049b430..fb048d7b19da0f010ed918b147013b20d37ed0dd 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 7cdfe2b1a612be2eec473d806d0eb44b611ca68a..75490aecfbe84810520c82597d127a36d36de3ee 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 203b2697babe32b45523109708cbf062dceee33b..49d72232aa0cfba3f5bf533de04f4d50e65275fd 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -45,12 +45,19 @@ from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver -if _six.PY2: - _to_bytes = lambda s: s - _to_string = lambda s: s -else: - _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") - _to_string = lambda s: s.decode("utf-8") + +def _to_bytes(s): + """Encode s if it is a sequence of chars.""" + if isinstance(s, _six.text_type): + return s.encode("utf-8", errors="surrogateescape") + return s + + +def _to_string(s): + """Decode s if it is a sequence of bytes.""" + if isinstance(s, _six.binary_type): + return s.decode("utf-8") + return s class TrtPrecisionMode(object): @@ -70,7 +77,7 @@ def get_tensorrt_rewriter_config(rewriter_config=None, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=None, + cached_engine_batches=None, use_calibration=True): """Returns a RewriterConfig proto for TRT transformation. @@ -90,9 +97,9 @@ def get_tensorrt_rewriter_config(rewriter_config=None, If the number of cached engines is already at max but none of them can serve the input, the TRTEngineOp will fall back to run the TF function based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached + cached_engine_batches: a list of batch sizes used to create cached engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will + should be <= maximum_cached_engines, and the dynamic TRT op will use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. @@ -143,14 +150,14 @@ def get_tensorrt_rewriter_config(rewriter_config=None, "max_workspace_size_bytes"].i = max_workspace_size_bytes optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines - if cached_engine_batch_sizes: - if not isinstance(cached_engine_batch_sizes, list): - raise TypeError("cached_engine_batch_sizes should be a list.") - if len(cached_engine_batch_sizes) > maximum_cached_engines: - raise ValueError("cached_engine_batch_sizes should not contain more than " + if cached_engine_batches: + if not isinstance(cached_engine_batches, list): + raise TypeError("cached_engine_batches should be a list.") + if len(cached_engine_batches) > maximum_cached_engines: + raise ValueError("cached_engine_batches should not contain more than " "maximum_cached_engines items.") optimizer.parameter_map["cached_engine_batches"].list.i.extend( - cached_engine_batch_sizes) + cached_engine_batches) optimizer.parameter_map["use_calibration"].b = use_calibration return rewriter_config_with_trt @@ -163,7 +170,7 @@ def create_inference_graph(input_graph_def, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=None, + cached_engine_batches=None, use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, @@ -190,9 +197,9 @@ def create_inference_graph(input_graph_def, If the number of cached engines is already at max but none of them can serve the input, the TRTEngineOp will fall back to run the TF function based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached + cached_engine_batches: a list of batch sizes used to create cached engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will + should be <= maximum_cached_engines, and the dynamic TRT op will use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. @@ -354,7 +361,7 @@ def create_inference_graph(input_graph_def, rewriter_config_with_trt = get_tensorrt_rewriter_config( rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batch_sizes, use_calibration) + cached_engine_batches, use_calibration) session_config_with_trt.graph_options.rewrite_options.CopyFrom( rewriter_config_with_trt) diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py index a7b2d2ea50543ba85c5a13dd6ca320e794ca47f1..abd822c7b71b4d7cca59482bdb51a922a28d480c 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py @@ -20,10 +20,10 @@ from __future__ import print_function import os -from tensorflow.contrib.tensorrt.python import trt_convert # pylint: disable=unused-import -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops # pylint: enable=unused-import +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -57,7 +57,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): minimum_segment_size=10, is_dynamic_op=True, maximum_cached_engines=2, - cached_engine_batch_sizes=[1, 128]) + cached_engine_batches=[1, 128]) self.assertEqual(["constfold", "layout", "constfold"], rewriter_cfg.optimizers) self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, @@ -84,8 +84,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): trt_optimizer.parameter_map["precision_mode"].s) self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i) self.assertEqual( - [1, 128], - trt_optimizer.parameter_map["cached_engine_batches"].list.i) + [1, 128], trt_optimizer.parameter_map["cached_engine_batches"].list.i) def _GetConfigProto(self): """Get ConfigProto for session creation.""" @@ -239,8 +238,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run with batch size 2, a new engine is created and cached. self._TestRun(sess, 2, True) # Run with batch size 3, since the number of cached engines has reached - # the max, it should fall back to TF function. - self._TestRun(sess, 3, False) + # the max, it should evict an old engine and create a new one. + self._TestRun(sess, 3, True) # Test the output SavedModel with ops.Graph().as_default(): @@ -251,8 +250,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): # Run with batch size 2, a new engine is created and cached. self._TestRun(sess, 2, True) # Run with batch size 3, since the number of cached engines has reached - # the max, it should fall back to TF function. - self._TestRun(sess, 3, False) + # the max, it should evict an old engine and create a new one. + self._TestRun(sess, 3, True) def testCreateInferenceGraph_StaticOp(self): if not trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h deleted file mode 100644 index aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/core/framework/resource_mgr.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTCalibrationResource : public tensorflow::ResourceBase { - public: - ~TRTCalibrationResource() { - LOG(INFO) << "Destroying Calibration Resource " << std::endl - << DebugString(); - builder_.reset(); - engine_.reset(); - // We need to manually destroy the builder and engine before the allocator - // is destroyed. - allocator_.reset(); - } - - string DebugString() override { - std::stringstream oss; - using std::dec; - using std::endl; - using std::hex; - oss << " Calibrator = " << hex << calibrator_.get() << dec << endl - << " Builder = " << hex << builder_.get() << dec << endl - << " Engine = " << hex << engine_.get() << dec << endl - << " Logger = " << hex << &logger_ << dec << endl - << " Allocator = " << hex << allocator_.get() << dec << endl - << " Thread = " << hex << thr_.get() << dec << endl; - return oss.str(); - } - - std::unique_ptr calibrator_; - TrtUniquePtrType builder_; - TrtUniquePtrType engine_; - std::unique_ptr allocator_; - tensorflow::tensorrt::Logger logger_; - // TODO(sami): Use threadpool threads! - std::unique_ptr thr_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b..5c60d6b589ed6a16276226726d989e949bcbf9d7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorrt/include/NvInfer.h" diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index ff317e43e1e6ff1c0b869ae8dc6d1fda8f0ce126..17e0b6f4d2c4bbaf56ef143b78c543c9e130b458 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -68,9 +68,9 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(100, 6, 6, 6)]) + expected_output_dims=[[[100, 6, 6, 6]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -125,9 +125,9 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(100, 12, 12, 6)]) + expected_output_dims=[[[100, 12, 12, 6]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -183,9 +183,9 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -253,9 +253,9 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -286,9 +286,9 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -320,9 +320,9 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -369,9 +369,9 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index f42308ecb7c8f8a107e78008abd3f470ddc85975..46e3407d9669085a9737bacbeec1a20765ef88cc 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -71,9 +71,9 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name, w1_name, w2_name], - input_dims=[input_dims, w1_dims, w2_dims], + input_dims=[[input_dims, w1_dims, w2_dims]], output_names=[output_name], - expected_output_dims=[(12, 5, 8, 7)]) + expected_output_dims=[[[12, 5, 8, 7]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -86,28 +86,6 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to run.""" return ["TRTEngineOp_1"] - def ShouldRunTest(self, run_params): - """Whether to run the test.""" - # TODO(aaroey): Trt library will fail like: - # - # ../builder/cudnnBuilder2.cpp:685: - # virtual std::vector> - # nvinfer1::builder::Node::getSupportedFormats( - # const nvinfer1::query::Ports&, - # const nvinfer1::cudnn::HardwareContext&, - # nvinfer1::builder::Format::Type, - # const nvinfer1::builder::FormatTypeHack&) const: - # Assertion `sf' failed. - # - # To reproduce, run: - # bazel test -c opt --copt=-mavx \ - # --test_arg=BatchMatMulTest.testTfTrt_ToolConversion_INT8_DynamicEngine \ - # tensorflow/contrib/tensorrt:batch_matmul_test - # - # Investigate and fix it. - return not trt_test.IsQuantizationMode(run_params.precision_mode) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 053b38ff1c0578c58f39dd6dc0630d1401a105af..ca23629efacba1df27ffb466d24b189d6074a917 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -111,9 +111,9 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(4, 6680)]) + expected_output_dims=[[[4, 6680]]]) def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" @@ -130,12 +130,6 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0"] - def ShouldRunTest(self, run_params): - """Whether to run the test.""" - # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8 - # mode, which is a bug. Re-enable this when trt library is fixed. - return not trt_test.IsQuantizationMode(run_params.precision_mode) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index 169835956c046dd675e967daa05fd81405662e38..846fd009db07b151e1eca794e9a8a38ff834a465 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -63,9 +63,9 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(5, 23040)]) + expected_output_dims=[[[5, 23040]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index c3576f81d97afe7e0e42cd10413971911e97774c..5d8742ae356c091ba831bbd48741dee34cd39d08 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -73,9 +73,9 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(2, 126)]) + expected_output_dims=[[[2, 126]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index c1c883312d867b60b88ac14318041f9750ca41e6..9137d0072f66321d8420b7caac6acc329541123f 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -58,9 +58,9 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(5, 12, 12, 1)]) + expected_output_dims=[[[5, 12, 12, 1]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/conv2d_test.py b/tensorflow/contrib/tensorrt/test/conv2d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7993b4620931736cd872bfffb4ebe177555fcd2 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/conv2d_test.py @@ -0,0 +1,191 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.platform import test + + +def conv2d_layer(inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + name=None): + dtype = inputs.dtype + c_axis = -1 if data_format == "channels_last" else 1 + nchan = inputs.shape[c_axis] + weights_shape = (kernel_size[0], kernel_size[1], nchan, filters) + weights = constant_op.constant(np.random.randn(*weights_shape), dtype=dtype) + padding = padding.upper() + if data_format == "channels_last": + strides = [1] + list(strides) + [1] + dilations = [1] + list(dilation_rate) + [1] + data_format = "NHWC" + else: + strides = [1, 1] + list(strides) + dilations = [1, 1] + list(dilation_rate) + data_format = "NCHW" + return gen_nn_ops.conv2d( + inputs, + weights, + strides=strides, + padding=padding, + dilations=dilations, + data_format=data_format) + + +def div_round_up(n, d): + return (n - 1) // d + 1 + + +def build_graph(input_dims, + dtype, + num_filters, + data_format, + kernel_sizes, + dilation_rates, + padding="same"): + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name="input") + with g.device("/GPU:0"): + results = [] + for kernel_size in kernel_sizes: + for dilation_rate in dilation_rates: + result = conv2d_layer(inp, num_filters, kernel_size, (1, 1), padding, + data_format, dilation_rate) + results.append(result) + output = sum(results) + output = array_ops.identity(output, name="output") + return g + + +class Conv2DNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion.""" + np.random.seed(1234) + input_dims = [13, 3, 7, 11] + g = build_graph( + input_dims=input_dims, + dtype=dtypes.float32, + num_filters=5, + data_format="channels_first", + kernel_sizes=[(3, 3), (3, 2)], + dilation_rates=[(1, 1), (2, 3)]) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=["input"], + input_dims=[[input_dims]], + output_names=["output"], + expected_output_dims=[[[13, 5, 7, 11]]]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion.""" + np.random.seed(1234) + input_dims = [13, 7, 11, 3] + g = build_graph( + input_dims=input_dims, + dtype=dtypes.float32, + num_filters=5, + data_format="channels_last", + kernel_sizes=[(3, 3), (3, 2)], + dilation_rates=[(1, 1), (2, 3)]) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=["input"], + input_dims=[[input_dims]], + output_names=["output"], + expected_output_dims=[[[13, 7, 11, 5]]]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of strided Conv2D (data_format=NCHW) in TF-TRT + + conversion. + """ + np.random.seed(1234) + dtype = dtypes.float32 + input_name = "input" + n, c, h, w = 13, 3, 7, 11 + num_filters = 5 + input_dims = [n, c, h, w] + output_name = "output" + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + output = inp + output = conv2d_layer( + output, + num_filters, (3, 2), + strides=(2, 2), + padding="same", + data_format="channels_first") + h = div_round_up(h, 2) + w = div_round_up(w, 2) + output = conv2d_layer( + output, + num_filters, (3, 3), + strides=(2, 2), + dilation_rate=(2, 3), + padding="same", + data_format="channels_first") + h = div_round_up(h, 2) + w = div_round_up(w, 2) + output = array_ops.identity(output, name=output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[[input_dims]], + output_names=[output_name], + expected_output_dims=[[[n, num_filters, h, w]]]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cc28cd6087997359e81ffaa6dc8bd958109cc565 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py @@ -0,0 +1,107 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to test TF-TRT INT8 conversion without calibration on Mnist model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + # TODO(laigd): we should test the following cases: + # - batch size is not changed, other dims are changing + # - batch size is decreasing, other dims are identical + # - batch size is decreasing, other dims are changing + # - batch size is increasing, other dims are identical + # - batch size is increasing, other dims are changing + input_dims = [[[1, 5, 5, 1]], [[10, 5, 5, 1]], [[3, 5, 5, 1]], + [[1, 5, 5, 1]], [[1, 3, 1, 1]], [[2, 9, 9, 1]], + [[1, 224, 224, 1]], [[1, 128, 224, 1]]] + expected_output_dims = input_dims + + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + shape=(None, None, None, 1), dtype=dtypes.float32, name="input") + conv_filter1 = constant_op.constant( + np.ones([3, 3, 1, 8]), name="weights1", dtype=dtypes.float32) + bias1 = constant_op.constant(np.random.randn(8), dtype=dtypes.float32) + x = nn.conv2d( + input=x, + filter=conv_filter1, + strides=[1, 1, 1, 1], + padding="SAME", + name="conv") + x = nn.bias_add(x, bias1) + x = nn.relu(x) + conv_filter2 = constant_op.constant( + np.ones([3, 3, 8, 1]), name="weights2", dtype=dtypes.float32) + bias2 = constant_op.constant(np.random.randn(1), dtype=dtypes.float32) + x = nn.conv2d( + input=x, + filter=conv_filter2, + strides=[1, 1, 1, 1], + padding="SAME", + name="conv") + x = nn.bias_add(x, bias2) + x = array_ops.identity(x, name="output") + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=["input"], + input_dims=input_dims, + output_names=["output"], + expected_output_dims=expected_output_dims) + + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + conversion_params = super(DynamicInputShapesTest, + self).GetConversionParams(run_params) + return conversion_params._replace( + maximum_cached_engines=10, + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + + def ExpectedEnginesToBuild(self, run_params): + return ["TRTEngineOp_0"] + + def ShouldRunTest(self, run_params): + return (run_params.dynamic_engine and + not trt_test.IsQuantizationMode(run_params.precision_mode)) + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-03 if run_params.precision_mode == "FP32" else 1.e-01 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-03 if run_params.precision_mode == "FP32" else 1.e-01 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/identity_output_test.py b/tensorflow/contrib/tensorrt/test/identity_output_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b568eeda945d997a832b7f71a5bfd8c42e127e65 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/identity_output_test.py @@ -0,0 +1,74 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This test checks a situation where the same tensor is considered as an output + +multiple times because it has been duplicated by 2+ indentity ops. Previously, +the tensor would be renamed multiple times, overwriting the output binding name +which resulted in a runtime error when the binding would not be found. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class IdentityTest(trt_test.TfTrtIntegrationTestBase): + + def _ConstOp(self, shape): + return constant_op.constant(np.random.randn(*shape), dtype=dtypes.float32) + + def GetParams(self): + """Testing engine with the same tensor repeated as output via identity.""" + input_name = 'input' + input_dims = [100, 32] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + + b = self._ConstOp((32, 4)) + x1 = math_ops.matmul(x, b) + b = self._ConstOp((1, 4)) + x1 = x1 + b + + out1 = array_ops.identity(x1, name='output1') + out2 = array_ops.identity(x1, name='output2') + iden1 = array_ops.identity(x1) + out3 = array_ops.identity(iden1, name='output3') + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[[input_dims]], + output_names=['output1', 'output2', 'output3'], + expected_output_dims=[[[100, 4]] * 3]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ['TRTEngineOp_0'] + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/int32_test.py b/tensorflow/contrib/tensorrt/test/int32_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf538703880b130322a7dd504094c7a298e6522 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/int32_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test conversion of graphs involving INT32 tensors and operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase): + + def _ConstOp(self, shape, dtype): + return constant_op.constant(np.random.randn(*shape), dtype=dtype) + + def GetParams(self): + """Test exclusion of ops which are not supported in INT32 mode by TF-TRT""" + input_name = 'input' + output_name = 'output' + input_dims = [100, 4] + dtype = dtypes.int32 + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + b = self._ConstOp((4, 10), dtype) + x = math_ops.matmul(x, b) + b = self._ConstOp((10,), dtype) + x = nn.bias_add(x, b) + x = array_ops.identity(x, name=output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[[input_dims]], + output_names=[output_name], + expected_output_dims=[[[100, 10]]]) + + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + conversion_params = super(ExcludeUnsupportedInt32Test, + self).GetConversionParams(run_params) + return conversion_params._replace( + max_batch_size=100, + maximum_cached_engines=1, + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return [] + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8 + # mode, which is a bug. Re-enable this when trt library is fixed. + return not trt_test.IsQuantizationMode(run_params.precision_mode) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/lru_cache_test.py b/tensorflow/contrib/tensorrt/test/lru_cache_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7702413e6cee667796b7fbf4121c6e0d9118d35c --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/lru_cache_test.py @@ -0,0 +1,78 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test LRUCache by running different input batch sizes on same network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class LRUCacheTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + dtype = dtypes.float32 + input_name = "input" + input_dims = [[[1, 10, 10, 2]], [[2, 10, 10, 2]], [[4, 10, 10, 2]], + [[2, 10, 10, 2]]] + expected_output_dims = [[[1, 10, 10, 1]], [[2, 10, 10, 1]], [[4, 10, 10, + 1]], + [[2, 10, 10, 1]]] + output_name = "output" + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + dtype=dtype, shape=[None, 10, 10, 2], name=input_name) + conv_filter = constant_op.constant( + np.random.randn(3, 3, 2, 1), dtype=dtypes.float32) + x = nn.conv2d( + input=x, + filter=conv_filter, + strides=[1, 1, 1, 1], + padding="SAME", + name="conv") + bias = constant_op.constant( + np.random.randn(1, 10, 10, 1), dtype=dtypes.float32) + x = math_ops.add(x, bias) + x = nn.relu(x) + x = array_ops.identity(x, name="output") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=input_dims, + output_names=[output_name], + expected_output_dims=expected_output_dims) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + def ShouldRunTest(self, run_params): + return (run_params.dynamic_engine and + not trt_test.IsQuantizationMode(run_params.precision_mode)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py index 1187c759b4b5483cbf5afe136401abe86d6ef989..aad7b9f30728cbb3f4ec5fa730c5dbe46fe9fc3f 100644 --- a/tensorflow/contrib/tensorrt/test/manual_test.py +++ b/tensorflow/contrib/tensorrt/test/manual_test.py @@ -67,9 +67,9 @@ class ManualTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=gdef, input_names=params_map['input_names'], - input_dims=params_map['input_dims'], + input_dims=[params_map['input_dims']], output_names=params_map['output_names'], - expected_output_dims=params_map['expected_output_dims']) + expected_output_dims=[params_map['expected_output_dims']]) def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index 104bac43a0b1166dcddee9920991582f33e93316..cc64329bbd53eaaebf7929e48bbfa8d8beeeadff 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -62,9 +62,9 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(2, 15, 15, 10)]) + expected_output_dims=[[[2, 15, 15, 10]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 293f93d8a78bc8ab06002d6fc01cb8d6a0738698..a14bb0396ece74c8de73008d2007bce5c763b0ed 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -75,9 +75,9 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(2, 4, 5, 4)]) + expected_output_dims=[[[2, 4, 5, 4]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 3e1e4b088ba200db2184dd64092cbc642a17cb3a..06a86bbb8df4c11a471475c040b74099a6fe2361 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -59,9 +59,9 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(2, 4, 5, 4)]) + expected_output_dims=[[[2, 4, 5, 4]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py index e7d6ec4ad395d38a06f97020f2f363009f2286c7..d68211a7ee344f3d07d01e308ee60246a61816f6 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.tensorrt.python import trt_convert # pylint: disable=unused-import -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops # pylint: enable=unused-import +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.core.protobuf import config_pb2 from tensorflow.python import data from tensorflow.python import keras @@ -144,7 +144,10 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): outputs=[OUTPUT_NODE_NAME], max_batch_size=max_batch_size, precision_mode='INT8', - max_workspace_size_bytes=4096 << 19, + # There is a 2GB GPU memory limit for each test, so we set + # max_workspace_size_bytes to 256MB to leave enough room for TF + # runtime to allocate GPU memory. + max_workspace_size_bytes=1 << 28, minimum_segment_size=2, use_calibration=False, ) @@ -271,7 +274,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): num_epochs=None, model_dir=model_dir)['accuracy'] logging.info('accuracy_tf_native: %f', accuracy_tf_native) - self.assertAllClose(accuracy_tf_native, 0.9662) + self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3) if trt_convert.get_linked_tensorrt_version()[0] < 5: return @@ -283,7 +286,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): num_epochs=None, model_dir=model_dir)['accuracy'] logging.info('accuracy_tf_trt: %f', accuracy_tf_trt) - self.assertAllClose(accuracy_tf_trt, 0.9677) + self.assertAllClose(0.9675, accuracy_tf_trt, rtol=1e-3, atol=1e-3) if __name__ == '__main__': diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/contrib/tensorrt/test/quantization_test.py index e425a3674635650d7292ab072178e98932e6b824..ce1b25ebf3c52ac5710dea654134925bb5b6ceca 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_test.py @@ -60,9 +60,9 @@ def _GetParams(add_quantization_nodes, dtype=dtypes.float32): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(8, 1)]) + expected_output_dims=[[[8, 1]]]) class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase): diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 563232fc12675d9e1b32b7ab461591af57beadb9..97159bb008068efbbcdb0fc6844890a42a08f46c 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -63,9 +63,9 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=input_names, - input_dims=input_dims, + input_dims=[input_dims], output_names=[output_name], - expected_output_dims=[tuple(input_dims[1])]) + expected_output_dims=[[input_dims[1]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -80,12 +80,6 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): ], } - def ShouldRunTest(self, run_params): - """Whether to run the test.""" - # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8 - # mode, which is a bug. Re-enable this when trt library is fixed. - return not trt_test.IsQuantizationMode(run_params.precision_mode) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index 207944468ab0b038abfe01f0096d7dc220d064ed..7fb2cbde07c4987d925e9abc915ede52119ec6df 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -72,9 +72,9 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[tuple(input_dims)]) + expected_output_dims=[[input_dims]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" @@ -129,9 +129,9 @@ class TransposeTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(24, 100, 2, 24)]) + expected_output_dims=[[[24, 100, 2, 24]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index d26f26008635733c6c364a98b72b88c1e552f5fe..090aa8bdb0487973e186631af3b4edac48096a5f 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -191,7 +191,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=[]) + cached_engine_batches=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -206,7 +206,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=[]) + cached_engine_batches=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], @@ -216,7 +216,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=[]) + cached_engine_batches=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 495a9391a1e818a6078988161c9bf72f6143737f..9a00cdb11a0f98d9b9be0d8e88a79cf45a8a7e3a 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -25,10 +25,10 @@ import warnings import numpy as np import six -from tensorflow.contrib.tensorrt.python import trt_convert # pylint: disable=unused-import -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops # pylint: enable=unused-import +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes @@ -39,9 +39,19 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging -TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ - "gdef", "input_names", "input_dims", "output_names", "expected_output_dims" -]) +TfTrtIntegrationTestParams = namedtuple( + "TfTrtIntegrationTestParams", + [ + "gdef", + # A list of names of the input placeholder nodes. + "input_names", + # A list of list of output shapes of the input placeholder nodes. + "input_dims", + # A list of names of the output identity nodes. + "output_names", + # A list of list of expected output shapes of the output identity nodes. + "expected_output_dims" + ]) RunParams = namedtuple("RunParams", [ "use_optimizer", "precision_mode", "dynamic_engine", "test_name", @@ -51,7 +61,7 @@ RunParams = namedtuple("RunParams", [ ConversionParams = namedtuple("ConversionParams", [ "max_batch_size", "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "cached_engine_batch_sizes", "rewriter_config", "use_calibration" + "cached_engine_batches", "rewriter_config", "use_calibration" ]) PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -159,16 +169,24 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" + batch_list = [] + for dims_list in self._GetParamsCached().input_dims: + assert dims_list + # Each list of shapes should have same batch size. + input_batches = [dims[0] for dims in dims_list] + assert max(input_batches) == min(input_batches) + batch_list.append(input_batches[0]) return ConversionParams( - max_batch_size=max([ - dims[0] for dims in self._GetParamsCached().input_dims if len(dims) - ]), + # We use the minimum of all the batch sizes, so when multiple different + # input shapes are provided it'll always create new engines in the + # cache, and we can therefore test the cache behavior. + max_batch_size=min(batch_list), max_workspace_size_bytes=1 << 25, precision_mode=run_params.precision_mode, minimum_segment_size=2, is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, - cached_engine_batch_sizes=None, + cached_engine_batches=None, rewriter_config=None, use_calibration=run_params.use_calibration) @@ -239,8 +257,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" + conversion_params = self.GetConversionParams(run_params) if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: - conversion_params = self.GetConversionParams(run_params) rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( conversion_params.rewriter_config, conversion_params.max_batch_size, conversion_params.max_workspace_size_bytes, @@ -248,12 +266,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): conversion_params.minimum_segment_size, conversion_params.is_dynamic_op, conversion_params.maximum_cached_engines, - conversion_params.cached_engine_batch_sizes, + conversion_params.cached_engine_batches, conversion_params.use_calibration) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() + if conversion_params.rewriter_config is not None: + graph_options.rewrite_options.CopyFrom( + conversion_params.rewriter_config) config = config_pb2.ConfigProto( gpu_options=self._GetGPUOptions(), graph_options=graph_options) @@ -280,13 +301,16 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _RunGraph(self, run_params, gdef, - input_data, + inputs_data, config, graph_state, num_runs=2): """Run given graphdef multiple times.""" params = self._GetParamsCached() - assert len(params.input_names) == len(input_data) + for current_input_data in inputs_data: + assert len(params.input_names) == len(current_input_data) + + vals = [] g = ops.Graph() with g.as_default(): io_ops = importer.import_graph_def( @@ -294,43 +318,48 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): return_elements=params.input_names + params.output_names, name="") inputs = [op.outputs[0] for op in io_ops[:len(params.input_names)]] - assert len(inputs) == len(input_data) + for current_input_data in inputs_data: + assert len(inputs) == len(current_input_data) outputs = [op.outputs[0] for op in io_ops[len(params.input_names):]] - with self.test_session( - graph=g, config=config, use_gpu=True, force_gpu=True) as sess: - val = None - # Defaults to 2 runs to verify result across multiple runs is same. - for _ in range(num_runs): - self._PrepareRun(graph_state) - new_val = sess.run( - outputs, {inputs[i]: input_data[i] for i in range(len(inputs))}) - output_len = len(params.expected_output_dims) - self.assertEqual(output_len, len(new_val)) - for i in range(output_len): - self.assertEqual(params.expected_output_dims[i], new_val[i].shape) - if val is not None: - self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06) - val = new_val - self.VerifyRun(run_params, graph_state) - return val + with self.test_session( + graph=g, config=config, use_gpu=True, force_gpu=True) as sess: + # Run for each input(s) shape + for shape_index in range(len(inputs_data)): + val = None + # Defaults to 2 runs to verify result across multiple runs is same. + for _ in range(num_runs): + self._PrepareRun(graph_state) + new_val = sess.run(outputs, { + inputs[i]: inputs_data[shape_index][i] + for i in range(len(inputs)) + }) + output_len = len(params.expected_output_dims[shape_index]) + self.assertEqual(output_len, len(new_val)) + for i in range(output_len): + self.assertEqual( + list(params.expected_output_dims[shape_index][i]), + list(new_val[i].shape)) + if val is not None: + self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06) + val = new_val + self.VerifyRun(run_params, graph_state) + vals.append(val) + return vals # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. - def _RunCalibration(self, run_params, gdef, input_data, config): + def _RunCalibration(self, run_params, gdef, inputs_data, config): """Run calibration on given graph.""" return self._RunGraph( - run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) + run_params, gdef, inputs_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, run_params, gdef): + def _GetTrtGraphDef(self, run_params, graph_state, gdef): """Return trt converted graphdef.""" params = self._GetParamsCached() conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) - config_for_trt = config_pb2.ConfigProto(gpu_options=self._GetGPUOptions()) - if conversion_params.rewriter_config is not None: - config_for_trt.graph_options.rewrite_options.CopyFrom( - conversion_params.rewriter_config) + config_for_trt = self._GetConfigProto(run_params, graph_state) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=params.input_names + params.output_names, @@ -340,7 +369,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): minimum_segment_size=conversion_params.minimum_segment_size, is_dynamic_op=conversion_params.is_dynamic_op, maximum_cached_engines=conversion_params.maximum_cached_engines, - cached_engine_batch_sizes=conversion_params.cached_engine_batch_sizes, + cached_engine_batches=conversion_params.cached_engine_batches, use_calibration=conversion_params.use_calibration, session_config=config_for_trt) @@ -474,26 +503,31 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype()) assert len(params.input_names) == len(input_dtypes) - input_data = [] - for i in range(len(params.input_names)): - dtype = input_dtypes[params.input_names[i]] - # Multiply the input by some constant to avoid all zeros input for integer - # types. - scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0 - dims = params.input_dims[i] - # TODO(laigd): add debug options. E.g. we can set the input data to be - # continuous natural numbers: - # seq = np.arange(np.prod(dims)) - # seq.resize(dims) - # input_data.append(scale * seq.astype(dtype)) - input_data.append((scale * np.random.random_sample(dims)).astype(dtype)) + inputs_data = [] + for inp in params.input_dims: + current_input_data = [] + for i in range(len(params.input_names)): + dtype = input_dtypes[params.input_names[i]] + # Multiply the input by some constant to avoid all zeros input for + # integer types. + scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0 + dims = inp[i] + # TODO(laigd): add debug options. E.g. we can set the input data to be + # continuous natural numbers: + # seq = np.arange(np.prod(dims)) + # seq.resize(dims) + # input_data.append(scale * seq.astype(dtype)) + current_input_data.append( + (scale * np.random.random_sample(dims)).astype(dtype)) + inputs_data.append(current_input_data) + self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL) # Get reference result without running trt. config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL) logging.info("Running original graph w/o trt, config:\n%s", str(config_no_trt)) - ref_result = self._RunGraph(run_params, input_gdef, input_data, + ref_result = self._RunGraph(run_params, input_gdef, inputs_data, config_no_trt, GraphState.ORIGINAL) # Run calibration if necessary. @@ -503,12 +537,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) if run_params.use_optimizer: - result = self._RunCalibration(run_params, input_gdef, input_data, + result = self._RunCalibration(run_params, input_gdef, inputs_data, calib_config) else: - calib_gdef = self._GetTrtGraphDef(run_params, input_gdef) + calib_gdef = self._GetTrtGraphDef(run_params, GraphState.CALIBRATE, + input_gdef) self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE) - result = self._RunCalibration(run_params, calib_gdef, input_data, + result = self._RunCalibration(run_params, calib_gdef, inputs_data, calib_config) infer_gdef = trt_convert.calib_graph_to_infer_graph( calib_gdef, run_params.dynamic_engine) @@ -527,10 +562,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Running final inference graph, config:\n%s", str(infer_config)) if not run_params.use_optimizer: - infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef) + infer_gdef = self._GetTrtGraphDef(run_params, GraphState.INFERENCE, + infer_gdef) self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE) - result = self._RunGraph(run_params, infer_gdef, input_data, infer_config, + result = self._RunGraph(run_params, infer_gdef, inputs_data, infer_config, GraphState.INFERENCE) self.assertAllClose( ref_result, diff --git a/tensorflow/contrib/tensorrt/test/topk_test.py b/tensorflow/contrib/tensorrt/test/topk_test.py new file mode 100644 index 0000000000000000000000000000000000000000..633a51982b9a6acf1926033628793c1edbd2d118 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/topk_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class TopKTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing Top-K in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [100, 100] + k = 5 + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + k_tensor = constant_op.constant(k, dtype=dtypes.int32, name="Const") + values, indices = nn_ops.top_k(x, k_tensor, name="TopK") + values = array_ops.identity(values, name="output_values") + indices = array_ops.identity(indices, name="output_indices") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[[input_dims]], + output_names=["output_values", "output_indices"], + expected_output_dims=[[[100, k], [100, k]]]) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return {"TRTEngineOp_0": ["Const", "TopK"]} + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index b6e5e32db1236684a06c2d44298b9a3d39667152..497ea2848aae42a61db4f8f5a5c973525d5892d9 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -100,9 +100,9 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name, input2_name], - input_dims=[input_dims, input2_dims], + input_dims=[[input_dims, input2_dims]], output_names=[output_name], - expected_output_dims=[(12, 5, 8, 12)]) + expected_output_dims=[[[12, 5, 8, 12]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index b29626d2c28b4def716aef9e2703b669b5e46374..b5fed73d2d75271e2c5c533670923d42f233e80b 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -70,9 +70,9 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(5, 6, 2, 2)]) + expected_output_dims=[[[5, 6, 2, 2]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index 9b0b189626050f678c71e9abbf7eb5296440d879..307128f1a89c46d63e851b6a7cd2d6abe7e39ff8 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -61,9 +61,9 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=[input_name], - input_dims=[input_dims], + input_dims=[[input_dims]], output_names=[output_name], - expected_output_dims=[(5, 2, 2, 6)]) + expected_output_dims=[[[5, 2, 2, 6]]]) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 6ea15fb8eff13663625420288a37ba002d57fa47..c12895c730047898f366bf651c798c1f1c5b93f7 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -99,9 +99,9 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/stat_summarizer.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" %} %ignoreall diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 57797214d1684550aa7ad2664b71d22b504f70ed..e10be88ece8ebba9635af955b3c3410f29e5503c 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -105,6 +105,7 @@ py_binary( data = ["data/multivariate_periods.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], + visibility = ["//visibility:public"], deps = select({ ":empty_condition": [], "//conditions:default": [], @@ -113,6 +114,7 @@ py_binary( "//tensorflow:tensorflow_py", "//tensorflow/contrib/timeseries/python/timeseries:estimators", "//tensorflow/contrib/timeseries/python/timeseries:model", + "//tensorflow/contrib/timeseries/python/timeseries:state_management", ], ) diff --git a/tensorflow/contrib/timeseries/examples/predict_test.py b/tensorflow/contrib/timeseries/examples/predict_test.py index 678fd71cd8b94ee0be46e10a9a673de55bd44215..b353f85cb5df0cf961d1900b241e4fa1a84a24b4 100644 --- a/tensorflow/contrib/timeseries/examples/predict_test.py +++ b/tensorflow/contrib/timeseries/examples/predict_test.py @@ -43,10 +43,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual([700], mean.shape) self.assertAllEqual([700], upper_limit.shape) self.assertAllEqual([700], lower_limit.shape) - # Check that variance hasn't blown up too much. This is a relatively good - # indication that training was successful. - self.assertLess(upper_limit[-1] - lower_limit[-1], - 1.5 * (upper_limit[0] - lower_limit[0])) def test_ar(self): (times, observed, all_times, mean, @@ -55,7 +51,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual(all_times.shape, mean.shape) self.assertAllEqual(all_times.shape, upper_limit.shape) self.assertAllEqual(all_times.shape, lower_limit.shape) - self.assertLess((upper_limit - lower_limit).mean(), 4.) if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 4b90b596b28efec83aa349782c4874d79b6817c7..2a22295197dc225cefbedf2736adeea5491a9fc2 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -281,6 +281,7 @@ py_library( "input_pipeline.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":feature_keys", ":model_utils", @@ -361,9 +362,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":math_utils", ":model", ":model_utils", - "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index bcadf4094e1e79fff1685515f2bde0b88f717cac..3626701d24163ef52564b42d8a630bd9c5a788eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -462,11 +461,12 @@ class ARModel(model.TimeSeriesModel): if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS: covariance = prediction_ops["covariance"] sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=sigma) - loss_op = -math_ops.reduce_sum(normal.log_prob(prediction)) + loss_op = -math_ops.reduce_sum( + math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss - loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) + loss_op = math_ops.reduce_sum( + math_ops.squared_difference(prediction, targets)) loss_op /= math_ops.cast( math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype) return loss_op @@ -965,16 +965,11 @@ class AnomalyMixtureARModel(ARModel): anomaly_variance = prediction_ops["anomaly_params"] anomaly_sigma = math_ops.sqrt( gen_math_ops.maximum(anomaly_variance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=anomaly_sigma) - log_prob = normal.log_prob(prediction) + log_prob = math_utils.normal_log_prob(targets, anomaly_sigma, prediction) else: assert self._anomaly_distribution == AnomalyMixtureARModel.CAUCHY_ANOMALY anomaly_scale = prediction_ops["anomaly_params"] - cauchy = distributions.StudentT( - df=array_ops.ones([], dtype=anomaly_scale.dtype), - loc=targets, - scale=anomaly_scale) - log_prob = cauchy.log_prob(prediction) + log_prob = math_utils.cauchy_log_prob(targets, anomaly_scale, prediction) return log_prob def loss_op(self, targets, prediction_ops): @@ -983,8 +978,7 @@ class AnomalyMixtureARModel(ARModel): covariance = prediction_ops["covariance"] # Normal data log probability. sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal1 = distributions.Normal(loc=targets, scale=sigma) - log_prob1 = normal1.log_prob(prediction) + log_prob1 = math_utils.normal_log_prob(targets, sigma, prediction) log_prob1 += math_ops.log(1 - self._anomaly_prior_probability) # Anomaly log probability. log_prob2 = self._anomaly_log_prob(targets, prediction_ops) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index aab330643862c1ccf073d2a0e34e1c475b1ec15f..b7375e5055e29efea3f23c3b9b9f3af59f45495b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math +import numpy as np + from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers @@ -43,6 +45,32 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +def normal_log_prob(loc, scale, x): + """Computes the Normal log pdf.""" + z = (x - loc) / scale + return -0.5 * (math_ops.square(z) + + np.log(2. * np.pi) + math_ops.log(scale)) + + +def cauchy_log_prob(loc, scale, x): + """Computes the Cauchy log pdf.""" + z = (x - loc) / scale + return (-np.log(np.pi) - math_ops.log(scale) - + math_ops.log1p(math_ops.square(z))) + + +def mvn_tril_log_prob(loc, scale_tril, x): + """Computes the MVN log pdf under tril scale. Doesn't handle batches.""" + x0 = x - loc + z = linalg_ops.matrix_triangular_solve( + scale_tril, x0[..., array_ops.newaxis])[..., 0] + log_det_cov = 2. * math_ops.reduce_sum(math_ops.log( + array_ops.matrix_diag_part(scale_tril)), axis=-1) + d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype) + return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1) + + d * np.log(2. * np.pi) + log_det_cov) + + def clip_covariance( covariance_matrix, maximum_variance_ratio, minimum_variance): """Enforce constraints on a covariance matrix to improve numerical stability. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 125750e7639ad40c481472a93353e6fb7055be96..cf5e749042afd83f927a3d22edfd3a9538ab2ffd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -78,7 +78,6 @@ py_library( srcs = ["kalman_filter.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -235,7 +234,6 @@ py_library( srcs = ["filtering_postprocessor.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index e9e2ac0aaf4c4d6c41f5007662f261af3de9bbd1..3fa2fbd9f77cb887c30fde264815728ca345f45a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -22,8 +22,6 @@ import abc import six -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -91,10 +89,10 @@ def cauchy_alternative_to_gaussian(current_times, current_values, outputs): """ del current_times # unused cauchy_scale = math_utils.entropy_matched_cauchy_scale(outputs["covariance"]) - individual_log_pdfs = distributions.StudentT( - df=array_ops.ones([], dtype=current_values.dtype), + individual_log_pdfs = math_utils.cauchy_log_prob( loc=outputs["mean"], - scale=cauchy_scale).log_prob(current_values) + scale=cauchy_scale, + x=current_values) return math_ops.reduce_sum(individual_log_pdfs, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py index a614386121e000961bf8b32625a28e1251654320..c0ec797bc5b7c41ca996c807840ce38311201f87 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -137,9 +135,10 @@ class KalmanFilter(object): with ops.control_dependencies([non_negative_assert]): observation_covariance_cholesky = linalg_ops.cholesky( symmetrized_observation_covariance) - log_prediction_prob = distributions.MultivariateNormalTriL( - predicted_observation, observation_covariance_cholesky).log_prob( - observation) + log_prediction_prob = math_utils.mvn_tril_log_prob( + loc=predicted_observation, + scale_tril=observation_covariance_cholesky, + x=observation) (posterior_state, posterior_state_var) = self.posterior_from_prior_state( prior_state=estimated_state, diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index ec8a273ea89f0b94db7b602494ea76207be8c1a2..c1a36fecc25545c6611ea09190dd89a8e1d82afe 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -61,6 +61,7 @@ py_library( py_library( name = "tpu_estimator", srcs = [ + "python/tpu/_tpu_estimator_embedding.py", "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", @@ -70,12 +71,17 @@ py_library( srcs_version = "PY2AND3", deps = [ ":async_checkpoint", + ":feature_column", + ":functional", + ":tpu_embedding", ":tpu_lib", + ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:function", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", @@ -155,6 +161,25 @@ tf_gen_op_wrapper_py( ], ) +tf_custom_op_library( + name = "python/ops/_tpu_ordinal_selector.so", + srcs = ["ops/tpu_ordinal_selector_op.cc"], +) + +tf_custom_op_py_library( + name = "tpu_ordinal_selector_py", + srcs = ["ops/gen_tpu_ordinal_selector_op.py"], + dso = [":python/ops/_tpu_ordinal_selector.so"], + kernels = [ + ":tpu_ordinal_selector_op_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":tpu_ordinal_selector_op", + ], +) + tf_gen_op_wrapper_py( name = "tpu_ordinal_selector_op", deps = [ @@ -242,7 +267,6 @@ py_library( "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", "//third_party/cloud_tpu/models/keras_colab:__subpackages__", - "//third_party/cloud_tpu/models/mnist_keras:__subpackages__", "//third_party/cloud_tpu/models/resnet50_keras:__subpackages__", ], deps = [ @@ -298,6 +322,7 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", + "//tensorflow/contrib/tpu/proto:dynamic_padding_proto_py", "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", @@ -337,13 +362,15 @@ py_library( tf_py_test( name = "datasets_test", + size = "medium", srcs = ["python/tpu/datasets_test.py"], additional_deps = [ "//tensorflow/python:client_testlib", ":datasets", ], - flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs grpc_enabled = True, + shard_count = 4, + tags = ["no_oss"], ) tf_py_test( @@ -430,7 +457,8 @@ py_library( srcs = ["python/tpu/tpu_embedding.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tpu:tpu_ops", + ":tpu_lib", + ":tpu_ops", "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index 285e11d92de7a684ed87974414ec73c274cc7aa5..d4180d1a20bc59f3fbb37b2dbc67790ded9d2d90 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -31,6 +31,7 @@ REGISTER_OP("TPUReplicateMetadata") // Deprecated. Use num_cores_per_replica instead. .Attr("computation_shape: list(int) = []") .Attr("host_compute_core: list(string) = []") + .Attr("padding_map: list(string) = []") .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TPUReplicatedInput") @@ -105,6 +106,7 @@ REGISTER_OP("TPUReplicate") .Attr("NumVariables: int >= 0") .Attr("Tguaranteed_constants: list(type) >= 0") .Attr("output_types: list(type) >= 0") + .Attr("padding_map: list(string) = []") .Input("inputs: Tinputs") .Input("broadcast_inputs: Tbroadcast_inputs") .Input("variables: NumVariables * resource") diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index 0ef29bdf734467aa9dee5c157bc8d8a7e0a85f13..676aed0b7b651494eda80ff2d7c7c31097529590 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -37,18 +37,18 @@ namespace tensorflow { // pieces of the TF Graph. // 1. Pass this TPUEmbeddingConfiguration to tpu.initialize_system() as the // tpu_embedding_config parameter. -// 2. Use the TPUEmbeddingLoad Op to initialize the embedding tables in TPU +// 2. Use the LoadTPUEmbedding Ops to initialize the embedding tables in TPU // memories, sharded across the memories attached to each Host. -// 3. Use TPUEmbeddingEnqueueSparseBatch to provide the TPU with embedding +// 3. Use EnqueueTPUEmbeddingSparseBatch to provide the TPU with embedding // indices and aggregation weights. -// 4. TPUEmbeddingReceiveActivations returns a list of Tensors, containing the +// 4. RecvTPUEmbeddingActivations returns a list of Tensors, containing the // activations from each table specified in the configuration. // 5. TPUEmbeddingActivations, when used with appropriate Python libraries, // enables the automatic differentiation of models that use embeddings. -// 6. TPUEmbeddingSendGradients takes a list of Tensors (of the same shapes +// 6. SendTPUEmbeddingGradients takes a list of Tensors (of the same shapes // as those returned by TPUEmbeddingReceiveActivations) containing gradients // to use in updating the embedding tables. -// 7. Before saving a checkpoint, use the TPUEmbeddingRetrieve Op to update +// 7. Before saving a checkpoint, use the RetrieveTPUEmbedding Ops to update // the Graph's embedding table Variables from the updated tables in the // TPU memories. // @@ -455,20 +455,21 @@ REGISTER_OP("SendTPUEmbeddingGradients") return Status::OK(); }) .Doc(R"doc( -An op that performs gradient updates of embedding tables. - -The TensorList argument has the same length and shapes as the return value of -TPUEmbeddingReceiveActivations, but contains gradients of the model's loss -with respect to the embedding activations. The embedding tables are updated -from these gradients via the optimizer specified in the configuration given -to tpu.initialize_system. +An op that performs gradient updates of embedding tables using the specified +learning rates. inputs: A TensorList of gradients with which to update embedding tables. - It contains one tensor per embedding table in the model. -learning_rates: A list of float32 scalars, one for each embedding table, - containing the learning rates for each table when dynamic learning rate is - enabled through the OptimizationParameters in TPUEmbeddingConfiguration. - When the learning rate is constant, the list should be empty. + This argument has the same length and shapes as the return value of + RecvTPUEmbeddingActivations, but contains gradients of the model's loss + with respect to the embedding activations. The embedding tables are updated + from these gradients via the optimizer specified in the TPU embedding + configuration given to tpu.initialize_system. +learning_rates: A TensorList of float32 scalars, one for each dynamic learning + rate tag: see the comments in + //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto. + Multiple tables can share the same dynamic learning rate tag as specified + in the configuration. If the learning rates for all tables are constant, + this list should be empty. config: Serialized TPUEmbeddingConfiguration proto. )doc"); diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index f27ae38e0434991da7475e631be1c6cb4a463118..807cf26fe983b4ebe17695d6f4f90ecfc0e0cbf5 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -33,7 +33,7 @@ setup( long_description='Tools for capture TPU profile', url='https://www.tensorflow.org/tfrc/', author='Google Inc.', - author_email='opensource@google.com', + author_email='packages@tensorflow.org', packages=['cloud_tpu_profiler'], package_data={ 'cloud_tpu_profiler': ['data/*'], diff --git a/tensorflow/contrib/tpu/profiler/trace_events.proto b/tensorflow/contrib/tpu/profiler/trace_events.proto index cb2b9162677a0ebe8240a98671b1cabc1cee0c9f..96c4784c691d8f34cf8715cdc0ed9886412f5f90 100644 --- a/tensorflow/contrib/tpu/profiler/trace_events.proto +++ b/tensorflow/contrib/tpu/profiler/trace_events.proto @@ -56,4 +56,7 @@ message TraceEvent { // The duration of the event in picoseconds if applicable. // Events without duration are called instant events. uint64 duration_ps = 10; + + // Extra arguments that will be displayed in trace view. + map args = 11; } diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index c20cab844cfaf083be2702a29ac2a152c7b72c2a..ea98ee25c89e1b7bef39276bae5c98bf382dbd7f 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -49,6 +49,15 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "dynamic_padding_proto", + srcs = [ + "dynamic_padding.proto", + ], + cc_api_version = 2, + visibility = ["//visibility:public"], +) + tf_proto_library_py( name = "compilation_result_proto", srcs = [ diff --git a/tensorflow/contrib/tpu/proto/dynamic_padding.proto b/tensorflow/contrib/tpu/proto/dynamic_padding.proto new file mode 100644 index 0000000000000000000000000000000000000000..c9ebf181169a583d774ef77ca0b8c243ce733615 --- /dev/null +++ b/tensorflow/contrib/tpu/proto/dynamic_padding.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.tpu; + +// A mapping between the dynamic shape dimension of an input and the arg that +// represents the real shape. +message PaddingMap { + // Input arg index with dynamic shapes. + int32 arg_index = 1; + + // The dynamic shape dimension index. + int32 shape_index = 2; + + // The arg index that dynamic dimension maps to, which represents the value + // of the real shape. + int32 padding_arg_index = 3; +} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index aae1ab1d37a166303883e3a07a7a01efe2feab51..bc50c613f3d2a09f9e51353fab4938055549a4cd 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -9,9 +9,38 @@ message ClippingLimits { google.protobuf.FloatValue upper = 2; // +inf if not set } -// Get the learning rate from the parameters of the SendTPUEmbeddingGradients -// op. +// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The +// actual learning rates are provided as a scalar input list to the +// SendTPUEmbeddingGradients Op indexed by their tag specified through the +// following proto. message DynamicLearningRate { + // For tables where learning rates are dynamically computed and communicated + // to the TPU embedding program, a tag must be specified for the learning + // rate. + // + // The tag must be a non-negative integer. The total number of unique tags + // must be less than or equal to the number of tables in the TPU embedding + // configuration (a table does not specify any tag if it uses a constant + // learning rate, and specifies exactly one tag if it uses dynamic learning + // rates). + // + // All tags in the range [0, number_of_unique_tags) must be present in the TPU + // embedding configuration, i.e. a tag cannot be skipped if a different tag + // numerically greater than it is used in the configuration. + // + // If multiple tables specify the same tag, they *MUST* have + // the same dynamic learning rate, for example, their dynamic learning rate + // could be computed by the same TensorFlow sub-graph. The partitioning of the + // embedding layer would be more optimal if the number_of_unique_tags is as + // *LOW* as possible, i.e., if many tables share the same tag. + // + // The learning_rate input of the SendTPUEmbeddingGradients op is used to + // communicate dynamic learning rates to the TPU embedding program. + // The learning_rate input is a list of scalars where the size of the list is + // equal to the number of unique tags. The learning rate associated with a + // particular tag is specified by populating its corresponding index in the + // list of learning_rate scalars. + int32 tag = 1; } // Source of learning rate to use. @@ -186,7 +215,8 @@ message OptimizationParameters { } // Specification of an optimization algorithm's state variables (both the main -// value vector and any extra accumulators, etc.). +// value vector and any extra accumulators, etc.). This proto is only used +// internally by the TPU software and is not exposed directly to the TF model. message StateVariableSpecification { // Parameter name for the state variable. string name = 1; @@ -194,6 +224,20 @@ message StateVariableSpecification { // A normal state variable that should be saved and restored in checkpoints // and used as an input or output to non-debug TensorFlow ops. message UserDefined { + // For padding embedding rows, this field specifies the initial value to be + // used. Separate initial values need to be specified for the embeddings and + // any extra accumulators. The initial values should be specified so as to + // maintain two invariants during model training: + // (1) The embedding vector multiplied by zero returns a vector containing + // all zeros. To maintain this invariant, the embedding values should + // never be NaNs or +-infinity. + // (2) Repeatedly applying the optimizer using a gradient vector of all + // zeros does not cause the embeddings or slot variables to become NaNs + // or +-infinity. + // The padding row is looked up when no embedding IDs are present for a + // feature. The semantics of embedding lookup dictate that the output must + // be zero under this scenario. + double padding_initial_value = 1; } // A state variable that should be filled with a constant and normally hidden diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 6a6eba282a12d68cc3cd4e46a46a1b4190fb737b..9260e7b8a800c3bf160923af95867d44342000a3 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -217,13 +217,19 @@ if platform.system() != "Windows": Args: inputs: A TensorList of gradients with which to update embedding tables. - Contains one tensor per embedding table in the model. + This argument has the same length and shapes as the return value of + RecvTPUEmbeddingActivations, but contains gradients of the model's + loss with respect to the embedding activations. The embedding tables + are updated from these gradients via the optimizers specified in the + TPU embedding configuration given to tpu.initialize_system. config: Serialized TPUEmbeddingConfiguration proto. - learning_rates: A TensorList of float32 scalars, one for each embedding - table, containing the learning rates for each table when dynamic - learning rate is enabled through the OptimizationParameters in - TPUEmbeddingConfiguration. When the learning rate is constant, the list - should be empty (optional). + learning_rates: A TensorList of float32 scalars, one for each dynamic + learning rate tag: see the comments in + //third_party/tensorflow/contrib/tpu/proto/ + optimization_parameters.proto. + Multiple tables can share the same dynamic learning rate tag as + specified in the configuration. If the learning rates for all tables + are constant, this list should be empty. name: A name for the operation (optional). Returns: @@ -337,9 +343,8 @@ if platform.system() != "Windows": Args: sample_indices: A list of rank 1 Tensors specifying the training example to which the corresponding embedding_indices and aggregation_weights - values - belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). + values belong. It corresponds to sp_ids.indices[:,0] in + embedding_lookup_sparse(). embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). aggregation_weights: A list of rank 1 Tensors containing per training diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce96e5bcdbe5777f68eb969be46423b5b3410cb --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -0,0 +1,273 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""Tooling for support TPU embedding in TPUEstimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc +from tensorflow.contrib.tpu.python.tpu import tpu_embedding +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.feature_column import feature_column as core_fc +from tensorflow.python.feature_column import feature_column_lib as core_fc_lib + +# pylint: disable=protected-access +_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn, + tpu_fc._TPUSharedEmbeddingColumn) +_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn, + core_fc_lib.EmbeddingColumn, + core_fc._SharedEmbeddingColumn) +_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn) + +# pylint: enable=protected-access + + +def get_tpu_embedding_config_from_feature_columns(feature_columns): + """Create configs for TPUEmbedding from a list of feature columns. + + This function will place one embedding tensor per table and the return is + intended to be used as input to TPUEmbedding. + + Args: + feature_columns: a list of supported feature columns. + + Returns: + A pair of dicts, the first maps tables to their config, the second maps + features to tables. + """ + + allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access + + for column in feature_columns: + if not isinstance(column, allowed): + raise TypeError( + 'Unsupported feature column {}. Supported types are {}.'.format( + type(column), allowed)) + + table_to_config = {} + feature_to_table = {} + for column in feature_columns: + feature_name = column.get_feature_key_name() + table_name = 'tbl_{}'.format(column.get_embedding_var_name()) + if feature_name in feature_to_table: + raise ValueError( + 'Feature column {} is used with multiple embeddings and this is ' + 'not supported.'.format(feature_name)) + feature_to_table[feature_name] = table_name + vocabulary_size, dimension = column.get_embedding_table_size() + table_to_config[table_name] = tpu_embedding.TableConfig( + vocabulary_size=vocabulary_size, + dimension=dimension, + initializer=column.get_initializer(), + combiner=column.get_combiner()) + + return table_to_config, feature_to_table + + +def _get_tpu_embedding_optimization_parameters(embedding_config_spec): + """Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec.""" + if embedding_config_spec.optimizer_type == 'adagrad': + return tpu_embedding.AdagradParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.adagrad_initial_accumulator, + embedding_config_spec.use_gradient_accumulation) + elif embedding_config_spec.optimizer_type == 'sgd': + return tpu_embedding.StochasticGradientDescentParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.use_gradient_accumulattion) + elif embedding_config_spec.optimizer_type == 'adam': + return tpu_embedding.AdamParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.adam_parameters.beta1, + embedding_config_spec.adam_parameters.beta2, + embedding_config_spec.adam_parameters.epsilon, + use_gradient_accumulation=embedding_config_spec + .use_gradient_accumulation) + else: + raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') + + +AdamParameters = collections.namedtuple('AdamParameters', + ['beta1', 'beta2', 'epsilon']) + + +# TODO(shizhiw): Improve the API to support more optimizer parameters in API. +class EmbeddingConfigSpec( + collections.namedtuple('EmbeddingConfigSpec', [ + 'feature_columns', 'learning_rate', 'optimizer_type', + 'adagrad_initial_accumulator', 'clipping_limit', + 'use_gradient_accumulation', 'adam_parameters' + ])): + """Class to keep track of embedding config specification.""" + + def __new__(cls, + feature_columns, + learning_rate, + optimizer_type='adagrad', + adagrad_initial_accumulator=None, + clipping_limit=None, + use_gradient_accumulation=False, + adam_parameters=None): + """Creates an EmbeddingConfigSpec instance. + + Args: + feature_columns: All `FeatureColumn`s used by model. + learning_rate: embedding optimizer learning rate. + optimizer_type: (String) Name of the optimizer for embedding gradients + updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default + value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam' + (`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will + be applied to all embedding variables specified by `feature_columns`. + adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when + optimizer_type is 'adagrad'. Default is `0.1`. + clipping_limit: (Optional) Clipping limit (absolute value). + use_gradient_accumulation: (Experimental) Whether to accumulate the + gradients across TPU embedding mini-batches. Gradient accumulation does + not affect SGD and therefore this is applicable only for Adagrad. + adam_parameters: AdamParameters. Used when optimizer_type is 'adam'. + Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon. + + Returns: + An EmbeddingConfigSpec instance. + + Raises: + ValueError: If the feature_columns are not specified. + TypeError: If the feature columns are not of ths correct type (one of + _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR + _EMBEDDING_COLUMN_CLASSES). + ValueError: If use_gradient_accumulation is True for SGD. + ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or + "adam". + """ + if not feature_columns: + raise ValueError('`feature_columns` cannot be `None` or empty.') + + # It is unknown at this moment, whether the TPUEstimator is running in CPU + # or TPU mode. So allow non-TPU embedding columns also. + supported_classes = tuple( + list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) + + list(_EMBEDDING_COLUMN_CLASSES)) + + for column in feature_columns: + if not isinstance(column, supported_classes): + raise TypeError( + 'All feature columns must be supported types in {}. Got {}'.format( + supported_classes, type(column))) + + if optimizer_type == 'adagrad': + if adagrad_initial_accumulator is None: + adagrad_initial_accumulator = 0.1 + if adagrad_initial_accumulator <= 0: + raise ValueError('Adagrad initial_accumulator must be positive') + elif optimizer_type == 'sgd': + if use_gradient_accumulation: + raise ValueError('Gradient accumulation makes sense for Adagrad only.') + elif optimizer_type == 'adam': + if adam_parameters is None: + adam_parameters = AdamParameters(0.9, 0.999, 1e-8) + if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.: + raise ValueError('beta1 must be between 0. and 1; got {}.'.format( + adam_parameters.beta1)) + if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.: + raise ValueError('beta2 must be between 0. and 1; got {}.'.format( + adam_parameters.beta2)) + if adam_parameters.epsilon <= 0.: + raise ValueError('epsilon must be positive; got {}.'.format( + adam_parameters.epsilon)) + else: + raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') + + return super(EmbeddingConfigSpec, cls).__new__( + cls, + feature_columns=feature_columns, + learning_rate=learning_rate, + optimizer_type=optimizer_type, + adagrad_initial_accumulator=adagrad_initial_accumulator, + clipping_limit=clipping_limit, + use_gradient_accumulation=use_gradient_accumulation, + adam_parameters=adam_parameters) + + +class EmbeddingConfig(object): + """This is the internal immutable object for embedding config. + + `_EmbeddingConfig` is responsible to _translate_ user provided + `EmbeddingConfigSpec` to internal data structures, mostly constructor + arguments of `TPUEmbedding`. + """ + + def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size, + num_hosts, num_cores, master): + self._embedding_config_spec = embedding_config_spec + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._num_hosts = num_hosts + self._num_cores = num_cores + self._master = master + + self._table_to_config_dict, self._feature_to_table_dict = ( + get_tpu_embedding_config_from_feature_columns( + embedding_config_spec.feature_columns)) + self._optimization_parameters = _get_tpu_embedding_optimization_parameters( + self._embedding_config_spec) + self._mode_to_tpu_embedding_dict = {} + + def has_embedding_tables(self): + return bool(self._table_to_config_dict) + + def _create_tpu_embedding(self, mode): + """Create tpu_embedding.TPUEmbedding based on mode.""" + if mode == model_fn_lib.ModeKeys.TRAIN: + batch_size = self._train_batch_size + else: + batch_size = self._eval_batch_size + + if mode == model_fn_lib.ModeKeys.TRAIN: + tpu_embedding_mode = tpu_embedding.TRAINING + elif (mode == model_fn_lib.ModeKeys.EVAL or + mode == model_fn_lib.ModeKeys.PREDICT): + tpu_embedding_mode = tpu_embedding.INFERENCE + else: + raise ValueError('Mode {} is not supported.'.format(mode)) + + tpu_embedding_ = tpu_embedding.TPUEmbedding( + self._table_to_config_dict, + self._feature_to_table_dict, + batch_size, + tpu_embedding_mode, + self._master, + self._optimization_parameters, + ) + return tpu_embedding_ + + def get_tpu_embedding(self, mode): + if mode not in self._mode_to_tpu_embedding_dict: + self._mode_to_tpu_embedding_dict[mode] = ( + self._create_tpu_embedding(mode)) + return self._mode_to_tpu_embedding_dict[mode] + + +def split_inputs(ctx, features, labels): + """Splits the dense and sparse tensors inside the features and labels.""" + sparse_features = collections.OrderedDict() + if ctx.embedding_config: + tpu_embedding_ = ctx.embedding_config.tpu_embedding + for feature_key in tpu_embedding_.feature_to_table_dict: + sparse_features[feature_key] = features.pop(feature_key) + + return features, labels, sparse_features diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 8d6245390fc3fa005c92d01bc9b64ddb47583582..bc0cd41d210ac6f8de1b20ebf744ee1e1dd04137 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -142,15 +142,12 @@ def StreamingFilesDataset(files, source_dataset = source_dataset.shuffle( buffer_size=filename_shuffle_buffer_size) - # NOTE: We perform the `repeat` on the source dataset, because the output - # dataset does not currently have enough information to recreate an iterator - # over the source dataset when it reaches the end. - source_dataset = source_dataset.repeat(num_epochs) - source_dataset = source_dataset.apply( interleave_ops.parallel_interleave( reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) + source_dataset = source_dataset.repeat(num_epochs) + if batch_transfer_size: source_dataset = source_dataset.batch(batch_transfer_size) diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 6906501ecf90c8e577aa0becf2dba818deb19df4..3313dc749c2c7606101b2dc96614df2d052dfed1 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -25,6 +25,9 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.tpu.topology import Topology +SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]] + + def _compute_task_and_cores_to_replicas(core_assignment, topology): """Computes a nested dict which maps task and logical core to replicas.""" task_and_cores_to_replicas = {} diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py index d5d00d628d407bf3bb5312bd54f6ccd13dc37db4..8edf131bc24fd003806263570b63ee8514c49896 100644 --- a/tensorflow/contrib/tpu/python/tpu/feature_column.py +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import math from tensorflow.contrib.tpu.python.tpu import tpu @@ -279,11 +278,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): if tpu.under_tpu_inference_context(): - # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() - # with outside compilation. - with _outside_all_rewrites(): + def host_computation(): return fc._EmbeddingColumn._get_dense_tensor( self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) if _is_running_on_cpu(): return fc._EmbeddingColumn._get_dense_tensor( @@ -300,13 +298,6 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): return tensor -@contextlib.contextmanager -def _outside_all_rewrites(): - """'Break out' of a tpu.rewrite() (or shard(), etc.).""" - with ops.control_dependencies(None): - yield - - class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._SharedEmbeddingColumn): """Core Shared Embedding Column.""" @@ -385,11 +376,10 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): if tpu.under_tpu_inference_context(): - # TODO(shizhiw, b/112012627, b/112336539): Replace _outside_all_rewrites() - # with outside compilation. - with _outside_all_rewrites(): + def host_computation(): return fc._SharedEmbeddingColumn._get_dense_tensor( self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) if _is_running_on_cpu(): return fc._SharedEmbeddingColumn._get_dense_tensor( diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 3e463823c820a3ef8628324f77e1a9caf8d385d5..f5735cecc38b7033f21fc4d4105cfead233379fa 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -185,7 +185,8 @@ def all_worker_devices(session): """Return a list of devices for each worker in the system.""" devices = session.list_devices() return [ - device.name for device in devices + device.name + for device in devices if ':CPU:' in device.name and 'coordinator' not in device.name ] @@ -255,12 +256,14 @@ class WatchdogManager(threading.Thread): self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,))) + timeout_ms=self.shutdown_timeout * 1000,), + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) def configure_and_run(self): - logging.info('Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', - self.shutdown_timeout, self.ping_interval) + logging.info( + 'Enabling watchdog timer with %d second timeout ' + 'and %d second ping interval.', self.shutdown_timeout, + self.ping_interval) self._reset_manager() self._running = True self.start() @@ -269,7 +272,8 @@ class WatchdogManager(threading.Thread): logging.info('Stopping worker watchdog.') self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,))) + watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), + shutdown_mode=event_pb2.NOT_CONFIGURED)) self._running = False self.join() diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index a1494e3660bc09e3af45e81097151a35990810fb..bf492e78a15acc92017663a286e8c8f0b2045339 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -47,6 +47,8 @@ _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' _TRACE_MODE_NORM = 'norm' _TRACE_MODE_MAX_ABS = 'max-abs' +_SUBMODE_BRIEF = 'brief' +_SUBMODE_DETAILED = 'detailed' _REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' _REASON_UNSAFE_OP = 'not-traced-unsafe-op' _REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' @@ -57,6 +59,7 @@ _REASON_SCALAR_GET_TRACED = 'traced-scalar' _REASON_TENSOR_GET_TRACED = 'traced-tensor' _REASON_USER_INCLUDED = 'traced-user-included' _REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' _REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! section-end:' @@ -68,6 +71,7 @@ _SECTION_NAME_GRAPH = 'graph' _FIELD_NAME_VERSION = 'version:' _FIELD_NAME_DEVICE = 'device:' _FIELD_NAME_TRACE_MODE = 'trace-mode:' +_FIELD_NAME_SUBMODE = 'submode:' _FIELD_NAME_NUM_REPLICAS = 'num-replicas:' _FIELD_NAME_NUM_OPS = 'number-of-ops:' _FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' @@ -76,8 +80,10 @@ _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') +_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') _FLAG_NAME_ENABLE = 'enable' _FLAG_NAME_TRACE_MODE = 'trace_mode' +_FLAG_NAME_SUBMODE = 'submode' _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' _FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' _FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' @@ -94,7 +100,7 @@ _TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' _TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' -def tensor_checkpoint(tensor, checkpoint_name): +def tensor_tracepoint(tensor, checkpoint_name): """Adds a checkpoint with the given checkpoint name for the given tensor. The tensor will be added to the list of tensors that will be traced by the @@ -115,10 +121,10 @@ def tensor_checkpoint(tensor, checkpoint_name): return tensor -def keras_layer_checkpoint(layer, checkpoint_name): +def keras_layer_tracepoint(layer, checkpoint_name): """An interface for adding the tensor outputs of a keras layer. - Encapsulates tensor_checkpoint. + Encapsulates tensor_tracepoint. Args: layer: A keras layer. @@ -132,12 +138,12 @@ def keras_layer_checkpoint(layer, checkpoint_name): try: outputs = layer.output if tensor_util.is_tensor(outputs): - tensor_checkpoint(outputs, '%s' % (checkpoint_name)) + tensor_tracepoint(outputs, '%s' % (checkpoint_name)) else: idx = 0 for output_tensor in outputs: if tensor_util.is_tensor(outputs): - tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) idx += 1 except AttributeError: pass @@ -165,21 +171,39 @@ class TensorTracer(object): @staticmethod def _match_next_flag(flags, pos): - """Returns the match for the next TensorTracer flag.""" + """Returns the match for the next TensorTracer flag. + + Args: + flags: a string that contains the flags. + pos: where in flags to start the search. + + Returns: + A pair where the first element is the regular-expression + match found and the second element indicates if the match + has a value. + """ match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) if match: - return match + return match, True match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) if match: - return match + return match, True match = _FLAG_NO_QUOTE_PAT.match(flags, pos) - return match + if match: + return match, True + match = _FLAG_NO_EQUAL_PAT.match(flags, pos) + if match: + # The flag is found but is not given a value. + return match, False + # The flag is not found. + return None, False @staticmethod def validate_flag_names(): """Validates if the TensorTrace flags passed are valid.""" valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES, _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES, @@ -193,7 +217,7 @@ class TensorTracer(object): return pos = 0 while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos) if not match: break flag_name = match.group(1) @@ -216,11 +240,15 @@ class TensorTracer(object): result += 'Individual flag value:\n' pos = 0 while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + match, has_value = TensorTracer._match_next_flag( + tensor_tracer_flags, pos) if not match: break flag_name = match.group(1) - flag_value = match.group(2) + if has_value: + flag_value = match.group(2) + else: + flag_value = None result += ' %s: %s\n'%(flag_name, flag_value) pos = match.end() result += '\n' @@ -228,30 +256,45 @@ class TensorTracer(object): @staticmethod def get_flag_value(wanted_flag_name): - """Returns the value of a TensorTracer flags.""" + """Returns the value of a TensorTracer flags. + + Args: + wanted_flag_name: the name the the flag we are looking for. + + Returns: + A pair where the first element indicates if the flag is + found and the second element is the value of the flag. + + Raises: + RuntimeError: If supposedly deadcode is reached. + """ tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) if not tensor_tracer_flags: - return '' + return False, None pos = 0 while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + match, has_value = TensorTracer._match_next_flag( + tensor_tracer_flags, pos) if not match: - return '' + return False, None flag_name = match.group(1) - flag_value = match.group(2) + if has_value: + flag_value = match.group(2) + else: + flag_value = None if flag_name == wanted_flag_name: - return flag_value + return True, flag_value pos = match.end() - return '' + raise RuntimeError('Should not reach here.') @staticmethod def flag_value_to_re_list(flag_name): """Converts list of strings to compiled RE.""" re_list = [] - flag_value = TensorTracer.get_flag_value(flag_name) - if not flag_value: + found, flag_value = TensorTracer.get_flag_value(flag_name) + if not found or not flag_value: return re_list list_of_values = flag_value.split() for v in list_of_values: @@ -260,32 +303,41 @@ class TensorTracer(object): return re_list @staticmethod - def is_enabled(): - """Returns True if TensorTracer is enabled.""" + def _is_flag_on(flag_name): + """Returns True if the given flag is on.""" - flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) + found, flag_value = TensorTracer.get_flag_value(flag_name) + if not found: + return False + if flag_value is None: + return True + # Depends on the flag value. flag_value = flag_value.lower() enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] return enabled + @staticmethod + def is_enabled(): + """Returns True if TensorTracer is enabled.""" + + return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE) + @staticmethod def use_test_undeclared_outputs_dir(): - """Decides the output directory of the trace file. + """Decides the output directory of the report and trace files. Args: None. Returns: - True if the output trace file should be written to the + True if the output files should be written to the test-undeclared-outputs-directory defined via an env variable. """ - flag_value = TensorTracer.get_flag_value( + return TensorTracer._is_flag_on( _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled + @staticmethod def check_device_type(device_type): @@ -306,6 +358,18 @@ class TensorTracer(object): 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) + @staticmethod + def check_submode(submode): + """Checks if the given submode is valid.""" + + if not submode: + return + valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] + if submode not in valid_submodes: + raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' + 'Valid submodes are: %s'%(submode, + valid_submodes)) + @staticmethod def unsafe_op(op): """Returns True if this op is not safe to be traced.""" @@ -314,8 +378,7 @@ class TensorTracer(object): return True # Reasons for not including following op types: # Assign: cause incorrect result with CPU tracing. - # others: compilation problems. - if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + if op.type in ['Assign']: return True return False @@ -350,10 +413,12 @@ class TensorTracer(object): def less_interesting_op(op): """Returns True if the given Op is not an interesting one to be traced.""" - include_less_interesting = TensorTracer.get_flag_value( + found, _ = TensorTracer.get_flag_value( _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) - if include_less_interesting: + if found: + # users force to include all ops. return False + # Following ops are highly unlikey to cause bugs. return op.type in ['Const', 'Identity', 'Cast', 'Shape'] @staticmethod @@ -404,7 +469,7 @@ class TensorTracer(object): temporarily_marked_ops, sorted_ops) # pylint: disable=protected-access for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access + # pylint: enable=protected-access visit(ctrl_output_op, cycle, permanently_marked_ops, temporarily_marked_ops, sorted_ops) temporarily_marked_ops.remove(op) @@ -460,10 +525,14 @@ class TensorTracer(object): self._version = 'use-outside-compilation' self._device_type = None TensorTracer.validate_flag_names() - self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) - if not self._trace_mode: + found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) + if not found or not self._trace_mode: self._trace_mode = _TRACE_MODE_NAN_INF TensorTracer.check_trace_mode(self._trace_mode) + found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE) + if not found or not self._submode: + self._submode = _SUBMODE_DETAILED + TensorTracer.check_submode(self._submode) self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE self._instrument_records = {} self._set_trace_file_path() @@ -499,8 +568,10 @@ class TensorTracer(object): def _set_trace_file_path(self): """Sets the path of the output trace file.""" - self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir(): + found, self._trace_file_path = TensorTracer.get_flag_value( + _FLAG_NAME_TRACE_FILE) + if found and self._trace_file_path \ + and TensorTracer.use_test_undeclared_outputs_dir(): if os.path.isabs(self._trace_file_path): raise ValueError('If use_test_undeclared_outputs_dir is set,' 'trace_file_path cannot be an absolute path (%s)' @@ -512,7 +583,17 @@ class TensorTracer(object): def _set_report_file(self): """Sets the path of the output report file.""" - self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE) + found, self._report_file_path = TensorTracer.get_flag_value( + _FLAG_NAME_REPORT_FILE) + if found and self._report_file_path \ + and TensorTracer.use_test_undeclared_outputs_dir(): + if os.path.isabs(self._report_file_path): + raise ValueError('If use_test_undeclared_outputs_dir is set,' + 'report_file_path cannot be an absolute path (%s)' + %self._report_file_path) + outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) + self._report_file_path = os.path.join(outputs_dir, + self._report_file_path) if not self._report_file_path: self._report_file = None return @@ -528,8 +609,8 @@ class TensorTracer(object): def _set_op_range(self): """Sets the index range of the Ops that we will consider tracing.""" - op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) - if not op_range: + found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) + if not found or not op_range: self._op_range = (-1, -1) # this means including all ops. return match = _OP_RANGE_PAT.match(op_range) @@ -595,6 +676,7 @@ class TensorTracer(object): self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) + self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode)) self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) @@ -606,7 +688,7 @@ class TensorTracer(object): self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - def _write_op_list_section(self, op_list, tensorname_idx_map): + def _write_op_list_section(self, op_list): """Writes the Op-list section of the report.""" self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) @@ -615,10 +697,10 @@ class TensorTracer(object): op = op_list[i] line = '%d "%s" %s'%(i, op.name, op.type) for out_tensor in op.outputs: - if out_tensor.name not in tensorname_idx_map: + if out_tensor.name not in self._tensorname_idx_map: raise ValueError( 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) - line += ' %d'%tensorname_idx_map[out_tensor.name] + line += ' %d'%self._tensorname_idx_map[out_tensor.name] line += '\n' self._write_report(line) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) @@ -653,12 +735,64 @@ class TensorTracer(object): self._write_report('%d "%s"\n'%(i, l[i].name)) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) - def _make_tensor_trace_fun(self, op_name, output_idx): + def _preprocess_traced_tensor(self, tensor): + """Computes NAN/Norm/Max on TPUs before sending to CPU. + + Args: + tensor: The tensor to be traced. + Returns: + A tensor that should be input to the trace_function. + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + output_tensor = math_ops.reduce_any( + gen_math_ops.logical_or( + gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) + else: + output_tensor = constant_op.constant(False) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = linalg_ops.norm(tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_max_abs(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) + zero = constant_op.constant(0, dtypes.float32) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf(tensor) + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm(tensor) + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs(tensor) + raise RuntimeError( + 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) + + def _make_tensor_trace_fun(self, tensor_name): """Makes the tensor tracing function called by outside compilation. Args: - op_name: the name of the Op that outputs the tensor to be traced. - output_idx: which output of the Op it is (0 means the first output). + tensor_name: name of the tensor being traced. Returns: A function to be passed as the first argument to outside compilation. @@ -667,84 +801,72 @@ class TensorTracer(object): RuntimeError: If the trace mode is invalid. """ - def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): + def _print_tensor(tensor_name, num_elements, tensor, output_tensor): """Prints a tensor value to a file. Args: - op_name: the name of the Op that outputs the tensor to be printed. - output_idx: which output of the Op it is (0 means the first output). + tensor_name: name of the tensor being traced. num_elements: number of elements to print (-1 means print all). tensor: the tensor needs to be returned. output_tensor: the tensor needs to be printed. Returns: The same tensor passed via the "tensor" argument. + + Raises: + ValueError: If tensor_name is not already in + self._tensorname_idx_map. """ - msg = '"%s:%d" '%(op_name, output_idx) + + if self._submode == _SUBMODE_BRIEF: + if tensor_name not in self._tensorname_idx_map: + raise ValueError( + 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) + msg = '%d'%self._tensorname_idx_map[tensor_name] + else: + msg = '"%s"'%tensor_name + if self._trace_file_path: output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path else: output_stream = sys.stderr print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), - ' @', self._replica_id, + '@', self._replica_id, '\n', output_tensor, '\n', summarize=num_elements, output_stream=output_stream) with ops.control_dependencies([print_op]): return array_ops.identity(tensor).op - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( - dtypes.float16): - # Since host can't handle bf16, always convert tensor to f32. - tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = math_ops.reduce_any( - gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), - gen_math_ops.is_inf(tensor))) - else: - output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - - def _show_norm(tensor): - tensor = math_ops.cast(tensor, dtypes.float64) - output_tensor = linalg_ops.norm(tensor) - return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) - - def _show_max_abs(tensor): - output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)), - dtypes.float64) - zero = constant_op.constant(0, dtypes.float64) - output_tensor = gen_math_ops.maximum(zero, output_tensor) - return _print_tensor(op_name, output_idx, -1, tensor, output_tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" - return _print_tensor(op_name, output_idx, self._part_tensor_size, + return _print_tensor(tensor_name, self._part_tensor_size, tensor, tensor) def _show_full_tensor(tensor): """Trace function for printing the entire tensor.""" - return _print_tensor(op_name, output_idx, -1, tensor, tensor) + return _print_tensor(tensor_name, -1, tensor, tensor) - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, + # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are + # performed within TPUs and only their results are transferred to CPU. + # Simply, print the full tensor for these trace modes. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, + _TRACE_MODE_MAX_ABS + ]: return _show_full_tensor - if self._trace_mode == _TRACE_MODE_NORM: - return _show_norm - if self._trace_mode == _TRACE_MODE_MAX_ABS: - return _show_max_abs raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) - def _skip_op(self, op_id, op, user_included, user_excluded): + def _skip_op(self, op_id, op, user_included, user_excluded, + in_exec_path=True): """Returns True if we should not trace Op.""" if user_included: @@ -755,6 +877,10 @@ class TensorTracer(object): self._instrument_records[op.name] = TensorTracer.reason( op_id, _REASON_USER_EXCLUDED) return True + if not in_exec_path: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_NOT_EXECUTED) + return True if not self._inside_op_range(op_id): self._instrument_records[op.name] = TensorTracer.reason( op_id, _REASON_OUTSIDE_OP_RANGE) @@ -797,9 +923,18 @@ class TensorTracer(object): op_id, _REASON_USER_EXCLUDED) return True if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_DYNAMIC_SHAPE) - return True + # If trace mode is nan-inf, norm or max, then the tensor will be reduced + # to a scalar before the outside compilation call. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS + ]: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True rank = len(out_tensor.shape) if rank < 1: # scalar @@ -817,14 +952,48 @@ class TensorTracer(object): op_id, _REASON_TENSOR_GET_TRACED) return False + def _filter_execution_path_operations(self, operations, fetches): + """Returns the set of ops in the execution path to compute given fetches.""" + # If no fetch provided, then return all operations. + if fetches is None: + return set(operations) + # Convert to list, if a single element is provided. + if not isinstance(fetches, (list, tuple)): + fetches = [fetches] + # If a tensor is given as fetch, convert it to op. + op_fetches = [] + for fetch in fetches: + if isinstance(fetch, ops.Operation): + op_fetches.append(fetch) + elif isinstance(fetch, ops.Tensor): + op_fetches.append(fetch.op) + else: + raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' + %fetch) + + execution_path_operations = set(op_fetches) + traverse_stack = list(op_fetches) + while True: + if not traverse_stack: + break + head_op = traverse_stack.pop() + input_ops = [tensor_input.op for tensor_input in head_op.inputs] + input_ops.extend(head_op.control_inputs) + + for input_op in input_ops: + if input_op not in execution_path_operations: + execution_path_operations.add(input_op) + traverse_stack.append(input_op) + return execution_path_operations + def _pre_tracing(self, graph): """Work needs to be done prior to TPU or CPU tracing.""" operations = graph.get_operations() - (opname_idx_map, tensor_list, tensorname_idx_map) = ( + (opname_idx_map, tensor_list, self._tensorname_idx_map) = ( TensorTracer._make_op_and_tensor_maps(operations)) self._write_config_section() - self._write_op_list_section(operations, tensorname_idx_map) + self._write_op_list_section(operations) self._write_tensor_list_section(tensor_list, opname_idx_map) # Does the topological sort before adding any nodes to the graph. (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) @@ -858,13 +1027,15 @@ class TensorTracer(object): _TENSOR_TRACER_CHECKPOINT)) return checkpoint_operations - def trace_tpu(self, graph, result_tensor, num_replicas=None): + def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None): """Traces the tensors generated by TPU Ops in a TF graph. Args: graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. num_replicas: number of replicas used on the TPU. + fetches: the list of fetches given to session.run, used to determine the + ops in execution path. If None, the whole graph will be traced. Returns: A tuple (result_tensor_copy, tracing_ops), where: @@ -876,11 +1047,27 @@ class TensorTracer(object): graph is evaluated. """ + def _cast_unsupported_dtypes(tensor): + """Casts tensor to a supported type.""" + + if tensor.dtype.__eq__(dtypes.int64): + # outside-compilation doesn't support int64 input yet. + return math_ops.cast(tensor, dtypes.int32) + if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( + dtypes.float16): + # Since host can't handle bf16, convert tensor to f32. + return math_ops.cast(tensor, dtypes.float32) + return tensor + self._device_type = _DEVICE_TYPE_TPU TensorTracer.check_device_type(self._device_type) result_tensor_copy = self._add_replica_id_to_graph(num_replicas, result_tensor) (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + # Filter out the operations that won't be executed. + # if fetches=None, then ops_in_exec_path = set(operations) + ops_in_exec_path = self._filter_execution_path_operations(operations, + fetches) tracing_ops = [] checkpoint_operations = self._get_checkpoints(graph) @@ -889,16 +1076,23 @@ class TensorTracer(object): continue user_included = self._is_user_included_op(op) user_excluded = self._is_user_excluded_op(op) - if self._skip_op(op_id, op, user_included, user_excluded): + in_exec_path = op in ops_in_exec_path + if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): continue for i in range(len(op.outputs)): out_tensor = op.outputs[i] if self._skip_tensor(op_id, out_tensor, user_included, user_excluded): continue + # Create the list of consumers before calling _preprocess_traced_tensor. + # Otherwise, adding control input below, will introduce a cycle in the + # graph. consumers = out_tensor.consumers() + tensor_name = out_tensor.name + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) + processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(op.name, i), out_tensor) + self._make_tensor_trace_fun(tensor_name), processed_out_tensor) if consumers: for consumer_op in consumers: # pylint: disable=protected-access @@ -944,8 +1138,9 @@ class TensorTracer(object): if self._skip_tensor(op_id, out_tensor, user_included, user_excluded): continue - trace_fun = self._make_tensor_trace_fun(op.name, i) - trace_call = (trace_fun, [out_tensor]) + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) + trace_fun = self._make_tensor_trace_fun(out_tensor.name) + trace_call = (trace_fun, [processed_out_tensor]) trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) tracing_calls[trace_call_key] = trace_call self._post_tracing(succeed, sorted_or_cycle) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 9266d81cf5fc035790062f0e307a5da0b01a9fc1..ebbccea02c70f06ac3e1231a359f2df4ebad3142 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -19,23 +19,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.compiler import xla from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.tpu.proto import dynamic_padding_pb2 as dynamic_padding from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.compat import compat as api_compat from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util import nest # Operations that indicate some error in the users graph, e.g. a placeholder @@ -480,14 +486,19 @@ def replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, - name=None): + name=None, + maximum_shapes=None): """Builds a graph operator that runs a replicated TPU computation. Args: computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -497,15 +508,125 @@ def replicate(computation, only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. + maximum_shapes: A nested structure of tf.TensorShape representing the shape + to which the respective component of each input element in each replica + should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a + tf.TensorShape or -1 in a tensor-like object) will be padded to the + maximum size of that dimension over all replicas. Note that if the input + dimension is already static, we won't do padding on it and we require the + maximum_shapes to have the same value or None on that dimension. The + structure of `maximum_shapes` needs to be the same as `inputs[0]`. Returns: - A list of lists of output tensors, indexed by `[replica_num][output_num]`. + A list of outputs, indexed by `[replica_num]` each output can be a nested + structure same as what computation() returns with a few exceptions. + + Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. + Raises: ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. + ValueError: If the static `inputs` dimensions don't match with the values + given in `maximum_shapes`. + ValueError: If the structure of inputs per replica does not match + the structure of `maximum_shapes`. """ - return split_compile_and_replicate(computation, inputs, infeed_queue, - device_assignment, name)[1] + return split_compile_and_replicate( + computation, + inputs, + infeed_queue, + device_assignment, + name, + maximum_shapes=maximum_shapes)[1] + + +def _pad_all_input(inputs, padded_shapes): + """Pad all input tensors given padded_shapes. + + The real shape tensors will be concatenated with the padded original inputs. + + Args: + inputs: The original inputs. + padded_shapes: A list of padded shapes for each input. + + Returns: + The padded inputs and a PaddingMap list which maps the padded input + dimension to the real shape argument index. + """ + input_shape_tensors = [] + for core_idx, inputs_per_core in enumerate(inputs): + for idx, input_tensor in enumerate(inputs_per_core): + if core_idx == 0: + input_shape_tensors.append([]) + input_shape_tensors[idx].append(array_ops.shape(input_tensor)) + + maximum_shapes = [] + for shapes_per_input in input_shape_tensors: + maximum_shapes.append( + math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) + + padded_inputs = [] + real_shapes = [] + padding_maps = [] + for core_idx, inputs_per_core in enumerate(inputs): + padded_inputs.append([]) + real_shapes.append([]) + real_shape_idx = len(inputs_per_core) - 1 + for idx, input_tensor in enumerate(inputs_per_core): + input_shape_tensor = input_shape_tensors[idx][core_idx] + input_shape = input_tensor.get_shape() + padded_shape = padded_shapes[idx] + + # The static shape of inputs should be compatible with the given padded + # shapes. + input_shape.assert_is_compatible_with(padded_shape) + + if input_shape.is_fully_defined(): + # Do nothing if the shape of the whole tensor is already static. + padded_inputs[core_idx].append(input_tensor) + else: + # Only pad the non static shape dimension. + for i, s in enumerate(input_shape): + if s.value is None: + if core_idx == 0: + real_shape_idx += 1 + padding_map = dynamic_padding.PaddingMap() + padding_map.arg_index = idx + padding_map.shape_index = i + padding_map.padding_arg_index = real_shape_idx + padding_maps.append(padding_map) + real_shapes[core_idx].append( + math_ops.cast(input_shape_tensor[i], dtypes.uint32)) + + paddings = [] + for i, s in enumerate(padded_shape): + if input_shape[i].value: + # Don't pad if input shape is already static. + padding = [0, 0] + else: + if s.value: + # Pad to the given maximum value. + padding = [0, s.value - input_shape_tensor[i]] + else: + # If maximum value is not given, then pad to the maximum dimension + # among all the cores. + padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]] + paddings.append(padding) + + padded_input = array_ops.pad(input_tensor, paddings) + padded_inputs[core_idx].append(padded_input) + + num_replicas = len(padded_inputs) + for i in range(num_replicas): + padded_inputs[i].extend(real_shapes[i]) + + return padded_inputs, padding_maps def split_compile_and_replicate(computation, @@ -513,7 +634,8 @@ def split_compile_and_replicate(computation, infeed_queue=None, device_assignment=None, name=None, - use_tpu=True): + use_tpu=True, + maximum_shapes=None): """Builds graph operators that runs compilation and replicated computation. This is a lower level interface than replicate that returns a separate compile @@ -526,7 +648,11 @@ def split_compile_and_replicate(computation, computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -539,6 +665,15 @@ def split_compile_and_replicate(computation, use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU backends. Currently, only supports a default placement (computation is placed on GPU if one is available, and on CPU if not). + maximum_shapes: A nested structure of tf.TensorShape representing the shape + to which the respective component of each input element in each replica + should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a + tf.TensorShape or -1 in a tensor-like object) will be padded to the + maximum size of that dimension over all replicas. Note that if the input + dimension is already static, we won't do padding on it and we require the + maximum_shapes to have the same value or None on that dimension. The + structure of `maximum_shapes` needs to be the same as `inputs[0]`. + Returns: A list of lists with the first list corresponding to the compile op and the second a list of output tensors, indexed by `[replica_num][output_num]`. @@ -546,6 +681,10 @@ def split_compile_and_replicate(computation, ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. + ValueError: If the static `inputs` dimensions don't match with the values + given in `maximum_shapes`. + ValueError: If the structure of inputs per replica does not match + the structure of `maximum_shapes`. """ del name inputs = [[]] if inputs is None else inputs @@ -580,24 +719,32 @@ def split_compile_and_replicate(computation, if num_replicas == 0: return [] + # Checks all replicas have the same structure. + for i in xrange(1, num_replicas): + nest.assert_same_structure(inputs[0], inputs[i]) + + # Flatten inputs. + flat_inputs = [ + nest.flatten(per_replica_input) for per_replica_input in inputs + ] # Converts inputs to Tensors. - inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] + flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] # Verifies that all replicas have matching numbers and types of inputs - input_types = [x.dtype for x in inputs[0]] - input_arity = len(input_types) + flat_input_types = [x.dtype for x in flat_inputs[0]] + input_arity = len(inputs[0]) + flat_input_arity = len(flat_input_types) for i in range(num_replicas): if len(inputs[i]) != input_arity: raise ValueError("Replicas must have the same number of inputs. " "Replica 0 had {} inputs, replica {} had {} " "inputs.".format(input_arity, i, len(inputs[i]))) - types = [x.dtype for x in inputs[i]] - if types != input_types: - raise ValueError( - "Replicas must have matching input types. Replica 0 had " - "input types {}, replica {} had input types {}".format( - input_types, i, types)) + types = [x.dtype for x in flat_inputs[i]] + if types != flat_input_types: + raise ValueError("Replicas must have matching input types. Replica 0 had " + "input types {}, replica {} had input types {}".format( + flat_input_types, i, types)) arg_error = xla.check_function_argument_count( computation, input_arity, infeed_queue) @@ -616,13 +763,34 @@ def split_compile_and_replicate(computation, for i in inputs[0]]), infeed_queue.number_of_tuple_elements, arg_error)) + if maximum_shapes: + if infeed_queue: + raise ValueError( + "Dynamic input shapes are not supported with infeed queues") + + # Make sure maximum_shapes has the same structure as inputs. + nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) + + # Flatten padded shapes. + flat_maximum_shapes = nest.flatten(maximum_shapes) + flat_maximum_shapes = [ + tensor_shape.TensorShape(s) for s in flat_maximum_shapes + ] + + flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes) + + serialized_padding_maps = [] + for padding_map in padding_maps: + serialized_padding_maps.append(padding_map.SerializeToString()) + metadata_kwargs["padding_map"] = serialized_padding_maps + graph = ops.get_default_graph() # Fan-in: Builds a TPUReplicatedInput node for each input. - computation_inputs = [] - for i in range(0, input_arity): - replicas = [inputs[replica][i] for replica in xrange(num_replicas)] - computation_inputs.append( + flat_replicated_inputs = [] + for i in range(0, len(flat_inputs[0])): + replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] + flat_replicated_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") @@ -642,15 +810,27 @@ def split_compile_and_replicate(computation, # computation. This is to avoid orphaned TPUReplicatedInput nodes. # TODO(phawkins): consider instead pruning unused TPUReplicatedInput # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. - computation_inputs = [ + flat_replicated_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs) + for i, x in enumerate(flat_replicated_inputs) ] - for i in computation_inputs: + for i in flat_replicated_inputs: # pylint: disable=protected-access - i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True)) + # Add an attribute to the identity node so that they could be removed in + # encapsulate TPU computation pass if unused. However we don't remove + # inputs when dynamic padding is enabled. + # TODO(rxsang): Use other ways except argument index in padding_map so + # outside compilation can work with dynamic padding correctly. + if maximum_shapes is None: + i.op._set_attr("_tpu_input_identity", + attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access + # Unflatten the computation inputs to match original input structure. + computation_inputs = nest.pack_sequence_as( + structure=inputs[0], + flat_sequence=flat_replicated_inputs[:flat_input_arity]) + # If there is an infeed queue, adds the dequeued values to the # computation's inputs. if infeed_queue is not None: @@ -691,51 +871,12 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) vscope.set_custom_getter(saved_custom_getter) - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - # Append `no_op` here so that fetching any return value of this function - # will trigger TPUExecute node. - outputs += (control_flow_ops.no_op(),) - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - output_arity = len(output_tensors) + outputs_is_flat = xla.is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - o = array_ops.identity(t) - # pylint: disable=protected-access - o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - new_output_tensors.append(o) - output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() @@ -747,11 +888,6 @@ def split_compile_and_replicate(computation, attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access - # Fan-out: Builds a TPUReplicatedOutput node for each output. - outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, - name="output{}".format(i)) - for i in xrange(output_arity)] - with ops.control_dependencies([metadata]): if use_tpu: compile_status = tpu_ops.tpu_compilation_result() @@ -761,28 +897,146 @@ def split_compile_and_replicate(computation, else: compile_status = control_flow_ops.no_op(name="compilation_status") - with ops.control_dependencies(output_operations): - if output_arity == 0: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - compile_status, [ - control_flow_ops.no_op(name="shard_%d" % i) - for i in range(num_replicas) - ] - ] - else: - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. - return [ - compile_status, [[ - array_ops.identity( - outputs[out][replica], - name="output_%d_shard_%d" % (out, replica)) - for out in xrange(output_arity) - ] - for replica in xrange(num_replicas)] + if not output_tensors: + # Returns a list of NoOps dependent on the replication Op, indexed by + # [replica_num]. + return [ + compile_status, + [ + control_flow_ops.group(control_deps, name="shard_%d" % i) + for i in range(num_replicas) + ] + ] + + # Fan-out: Builds a TPUReplicatedOutput node for each output. + replicated_outputs = [[] for i in xrange(num_replicas)] + for i, t in enumerate(output_tensors): + # Fan-out: Builds a TPUReplicatedOutput node for each output. + ys = tpu_ops.tpu_replicated_output( + t, num_replicas, name="output{}".format(i)) + + # Wraps the outputs in identity operators so the names of any possible + # `fetch` nodes are preserved by the replication rewrite. + with ops.control_dependencies(control_deps): + for replica in xrange(num_replicas): + replicated_outputs[replica].append( + array_ops.identity( + ys[replica], name="output_%d_shard_%d" % (i, replica))) + + if not outputs_is_flat: + replicated_outputs = [ + nest.pack_sequence_as(outputs, replica_outs) + for replica_outs in replicated_outputs + ] + + return [compile_status, replicated_outputs] + + +def _postprocess_flat_outputs(outputs): + """Validates non-flat outputs, add backs device assignments and other attrs. + + Args: + outputs: Output from `computation` inside `tpu.rewrite`. + + Returns: + Tensors and Operations extracted from outputs. + """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, makes it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. + outputs += (control_flow_ops.no_op(),) + try: + with ops.device(core(0)): + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs ] + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + "convertible to Tensors. Got '%s'" % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + "TPU functions must return zero-or more Tensor values followed by " + "zero or more Operations.") + + # Wraps outputs in Identity ops. Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead. + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else core(0)): + o = array_ops.identity(t) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + new_output_tensors.append(o) + return new_output_tensors, output_operations + + +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs, add backs device assignments and other attrs. + + Args: + outputs: Output from `computation` inside `tpu.rewrite`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs.. + """ + + # Flatten output items. + flat_outputs = nest.flatten(outputs) + + # Convert all non-Operation outputs to Tensors. + for i, o in enumerate(flat_outputs): + if isinstance(o, ops.Operation): + raise ValueError( + "tpu.rewrite does not support Operation as return value in non-flat " + "output structure. You can set returned Operations as control " + "dependencies of returned Tensors so Operations are triggered when " + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + 'convertible to Tensors. Got error: "%s"' % str(e)) + + # Wraps outputs in Identity ops. Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead. + with ops.device(core(0)): + o = array_ops.identity(o) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + flat_outputs[i] = array_ops.identity(o) + + # All flat_outputs are Tensors, and no Operations. + return flat_outputs, [] def split_compile_and_shard(computation, @@ -809,9 +1063,6 @@ def split_compile_and_shard(computation, return x + 3 ... = shard(computation, ...) - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. - If `outputs_from_all_shards` is true, the outputs from all shards of `computation` are concatenated back together along their `output_shards_axes`. Otherwise, each output is taken from an arbitrary shard. @@ -853,6 +1104,8 @@ def split_compile_and_shard(computation, ValueError: If len(input_shard_axes) != len(inputs) ValueError: If len(output_shard_axes) != len(outputs from `computation`) """ + # TODO(phawkins): consider adding support for broadcasting Tensors passed as + # inputs. if num_shards <= 0: raise ValueError("num_shards must be a positive integer.") @@ -1092,6 +1345,11 @@ def rewrite(computation, All `Operation`s constructed during `computation` will be executed when evaluating any of the returned output tensors, not just the ones returned. inputs: A list of input tensors or `None` (equivalent to an empty list). + Each input can be a nested structure containing values that are + convertible to tensors. Note that passing an N-dimension list of + compatible values will result in a N-dimention list of scalar tensors + rather than a single Rank-N tensors. If you need different behavior, + convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. device_assignment: if not `None`, a `DeviceAssignment` describing the @@ -1100,11 +1358,15 @@ def rewrite(computation, case the core attached to task 0, TPU device 0 is used. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. """ - if inputs is not None and not isinstance(inputs, (list, tuple)): - raise TypeError("tpu.rewrite() inputs must be a list or tuple") - # TODO(b/36647078) remove disable when pylint bug is fixed. # pylint: disable=indexing-exception return replicate( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 672462447944b777375331d49727c4d5366cf295..ed1e0f0401a96c34e6ff9323685857b64e10bd14 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -21,6 +21,7 @@ from __future__ import print_function from contextlib import contextmanager import copy +from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib @@ -192,8 +193,14 @@ class _InternalTPUContext(object): ``` """ - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu=True): + def __init__(self, + config, + train_batch_size, + eval_batch_size, + predict_batch_size, + use_tpu, + eval_on_tpu=True, + embedding_config_spec=None): self._config = config self._train_batch_size = train_batch_size self._eval_batch_size = eval_batch_size @@ -208,7 +215,7 @@ class _InternalTPUContext(object): use_tpu and config.tpu_config.num_cores_per_replica) self._mode = None num_cores_per_replica = config.tpu_config.num_cores_per_replica - if num_cores_per_replica: + if self._model_parallelism_enabled: self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ num_cores_per_replica] else: @@ -216,6 +223,8 @@ class _InternalTPUContext(object): self._lazy_tpu_system_metadata_dict = {} # key by master address self._lazy_device_assignment_dict = {} # key by master address self._lazy_validation_dict = {} # key by ModeKeys + self._embedding_config_spec = embedding_config_spec + self._lazy_embedding_config_dict = {} # key by master address def _assert_mode(self): if self._mode is None: @@ -293,6 +302,30 @@ class _InternalTPUContext(object): self._lazy_device_assignment_dict[master] = device_assignment return device_assignment + @property + def embedding_config(self): + """Returns the embedding config based on current mode.""" + master = self._get_master_address() + if master in self._lazy_embedding_config_dict: + embedding_config = self._lazy_embedding_config_dict[master] + else: + embedding_config = None + if self._use_tpu and self._embedding_config_spec: + embedding_config = _tpu_estimator_embedding.EmbeddingConfig( + self._embedding_config_spec, self._train_batch_size, + self._eval_batch_size, self.num_hosts, self.num_cores, master) + if not embedding_config.has_embedding_tables(): + embedding_config = None + self._lazy_embedding_config_dict[master] = embedding_config + + if embedding_config is not None: + mode = self._assert_mode() + # Dynamically attach tpu_embedding based on mode. With + # this, we could keep embedding_config immutable but call site always + # accesses the unified API '.tpu_embedding'. + embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode) + return embedding_config + @property def model_parallelism_enabled(self): return self._model_parallelism_enabled @@ -710,11 +743,15 @@ class _OneCoreTPUContext(_InternalTPUContext): def _get_tpu_context(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu): + predict_batch_size, use_tpu, eval_on_tpu, + embedding_config_spec): """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and config.tpu_config.num_cores_per_replica is None): + if embedding_config_spec is not None: + raise ValueError('Setting TPUConfig.num_shards==1 is unsupported ' + 'when embedding_config_spec is not None.') logging.warning( 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' 'Please fix as soon as possible (leaving num_shards as None.)') @@ -722,4 +759,5 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu) return _InternalTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu) + predict_batch_size, use_tpu, eval_on_tpu, + embedding_config_spec) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index ccba8a46c7cad0337119672e02314684f4451479..04e7397162624dfc1f6203dd267c1c1b90163dd4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -43,19 +44,6 @@ from tensorflow.python.ops import variables TRAINING = elc.TPUEmbeddingConfiguration.TRAINING INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE -# TODO(shizhiw): A better interface is to make `num_hosts` and -# `num_cores_per_host` optional parameters for `TPUEmbedding` -# constructor. Usually they can be automatically detected, but -# user can also specify them for debugging (b/112112496). -# Auto-detection can be done with `tpu_system_metadata.py`. -_MASTER_JOB = 'tpu_worker' -_HOST_PATTERN = '/job:tpu_worker/task:{}/device:CPU:0' -_NUM_CORES_PER_HOST = 8 - -_TEST_MASTER_JOB = None -_TEST_HOST = '/replica:0/task:0/device:CPU:0' -_TEST_NUM_CORES_PER_HOST = 2 - class TableConfig( collections.namedtuple( @@ -301,10 +289,9 @@ class TPUEmbedding(object): table_to_config_dict, feature_to_table_dict, batch_size, - num_hosts, mode, - optimization_parameters=None, - tpu_embedding_test=False): + master, + optimization_parameters=None): """API for using TPU for embedding lookups. Args: @@ -315,12 +302,11 @@ class TPUEmbedding(object): to string of table name. Feature refers to ids to lookup in embedding table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. batch_size: An `int` representing the global batch size. - num_hosts: An `int` representing the number of TPU hosts. mode: `TRAINING` or `INFERENCE`. + master: A `string` representing the TensorFlow master to use. optimization_parameters: `AdagradParameters`, `AdamParameters`, `Stochasticgradientdescentparameters`. Must be set in training and must be `None` in inference. - tpu_embedding_test: A `bool`. Only used for testing. Raises: ValueError: if any input is invalid. @@ -337,15 +323,17 @@ class TPUEmbedding(object): self._batch_size = batch_size - if tpu_embedding_test: - self._num_hosts = 1 - self._hosts = [_TEST_HOST] - self._num_cores_per_host = _TEST_NUM_CORES_PER_HOST - else: - self._num_hosts = num_hosts - self._hosts = [_HOST_PATTERN.format(i) for i in range(self._num_hosts)] - self._num_cores_per_host = _NUM_CORES_PER_HOST - self._num_cores = self._num_cores_per_host * self._num_hosts + self._master = master + self._tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata(self._master)) # pylint: disable=protected-access + if self._tpu_system_metadata.num_cores == 0: + raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' + 'TPUs.'.format(self._master)) + self._num_hosts = self._tpu_system_metadata.num_hosts + self._hosts = [device.name for device in self._tpu_system_metadata.devices + if 'device:CPU:' in device.name] + self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host + self._num_cores = self._tpu_system_metadata.num_cores _validate_batch_size(self._batch_size, self._num_cores) self._batch_size_per_core = self._batch_size // self._num_cores @@ -389,7 +377,7 @@ class TPUEmbedding(object): Returns: A list of device names for CPU hosts. """ - return self._hosts + return copy.copy(self._hosts) # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and # to be consistent with `tpu_embedding_configuration.proto`. @@ -452,6 +440,10 @@ class TPUEmbedding(object): def table_to_table_variables_dict(self): return copy.copy(self._table_to_table_variables_dict) + @property + def feature_to_table_dict(self): + return copy.copy(self._feature_to_table_dict) + def get_slot_names(self): """Return a list of the names of slots created by `TPUEmbedding`.""" return self._optimizer_handler.get_slot_names() @@ -1077,34 +1069,3 @@ def _create_partitioned_variables(name, initializer=initializer, collections=collections, trainable=False)) - - -@ops.RegisterGradient('TPUEmbeddingActivations') -def _embedding_activations_grad(activations_op, grad_wrt_activations): - """Saves the gradient of embedding activations ops in a graph collection.""" - g = ops.get_default_graph() - table_id = activations_op.get_attr('table_id') - lookup_id = activations_op.get_attr('lookup_id') - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - - if not table_gradients: - raise RuntimeError( - 'Gradients for TPUEmbedding have been generated in non-training mode. ' - 'This is not expected. Consider putting your Optimizer.minimize code ' - 'behind the training mode condition check. For Estimator, you can ' - 'do \n\n' - ' if mode == tf.estimator.ModeKeys.TRAIN:\n' - ' train_op = opt.minimize(loss)\n' - '\n') - - table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) - return [ - # RegisterGradient requires that value be returned for all inputs. Since - # the first argument (tpu_gradient_variable_{table_name}) has shape [1], - # we will return zeros(shape=[1]). The actual gradient w.r.t. the - # embedding activations (grad_wrt_activations) has the same shape as the - # activations returned by embedding_activations. - array_ops.zeros(arg.shape, dtype=dtypes.float32) - for arg in activations_op.inputs - ] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 87a970f0523363426b0da5b12838b797d7f8bebb..0620598ea00316d112245fa17bf5e56b1a015af4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,17 +31,22 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.ops import gen_tpu_ordinal_selector_op from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result -from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import error_handling +from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional from tensorflow.contrib.tpu.python.tpu import session_support +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_context from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib +from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import AdamParameters # pylint: disable=unused-import +from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec # pylint: disable=unused-import from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary @@ -55,6 +60,7 @@ from tensorflow.python.estimator.export import export_output as export_output_li from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -90,6 +96,7 @@ _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor' # Ideally _USE_TPU_KEY should be reserved as well. However there are already # models that make use of this key, thus it can not be reserved now to prevent @@ -120,6 +127,16 @@ def _is_iterable(obj): return False +class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext): + + def AddOp(self, op): + if op.type in [ + 'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary', + 'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2' + ]: + raise ValueError('Use tf.contrib.summary inside of host_calls.') + + def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: @@ -427,13 +444,20 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): run_infeed_loop_on_coordinator=True, rendezvous=None, master=None, - session_config=None): + session_config=None, + tpu_init_ops=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous self._master = master self._session_config = session_config + self._init_ops = list(tpu_init_ops or []) + if ctx.embedding_config is None: + self._embedding_layer_config = None + else: + self._embedding_layer_config = ( + ctx.embedding_config.tpu_embedding.config_proto) self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) @@ -446,7 +470,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [] if self._should_initialize_tpu: self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] else: @@ -506,7 +529,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): with ops.Graph().as_default(): with tf_session.Session( self._master, config=self._session_config) as sess: - sess.run(tpu.initialize_system(job=self._master_job)) + sess.run( + tpu.initialize_system( + job=self._master_job, + embedding_config=self._embedding_layer_config)) logging.info('Initialized TPU in %d seconds', time.time() - start) session.run(self._init_ops, @@ -848,6 +874,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( """Generates the per_host enqueue ops.""" control_deps = [] per_host_sharded_inputs = [] + sparse_features_list = [] num_replicas_per_host = ctx.num_of_replicas_per_host cached_signals = None with ops.device(device): @@ -866,6 +893,10 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( else: cached_signals = signals + features, labels, sparse_features = ( + _tpu_estimator_embedding.split_inputs(ctx, features, labels)) + sparse_features_list.append(sparse_features) + inputs_structure_recorder.validate_and_record_structure( features, labels) flattened_inputs = ( @@ -894,6 +925,11 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( tpu_ordinal_function=tpu_ordinal_function_impl) captured_infeed_queue.capture(infeed_queue) + if ctx.embedding_config: + per_host_enqueue_ops.extend( + ctx.embedding_config.tpu_embedding.generate_enqueue_ops( + sparse_features_list)) + if signals is None: return per_host_enqueue_ops else: @@ -1303,6 +1339,44 @@ class _InputPipeline(object): logging.warn(err_msg) +def call_computation(computation, + experimental_exported_model_uses_all_cores=True): + """Call computation. + + computation uses a single-core for TPU inference. If + `experimental_exported_model_uses_all_cores` is `True`, this function will + round-robin + computation among all TPU cores visible to the host; otherwise, it will use + a single core. + + Args: + computation: A Python function that takes no inputs and builds computation + graph. If `computation` returns m outputs, this function will return a + list of m Tensors. + experimental_exported_model_uses_all_cores: Whether to round-robin among all + cores visible to the host, or to use a single core. + + Returns: + A list of output tensors. + """ + if experimental_exported_model_uses_all_cores: + # Using `TPUPartitionedCall` makes it possible to target a different + # TPU core with every `Session.run()` call. Note that the entire inference + # graph executes on a single core, and that invocations of this graph + # will round-robin among the cores attached to a host. + @function.Defun() + def tpu_subgraph(): + return computation() + + return tpu_functional.TPUPartitionedCall( + args=tpu_subgraph.captured_inputs, + device_ordinal=gen_tpu_ordinal_selector_op.tpu_ordinal_selector(), + Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], + f=tpu_subgraph) + else: + return computation() + + class _ModelFnWrapper(object): """A `model_fn` wrapper. @@ -1322,6 +1396,12 @@ class _ModelFnWrapper(object): def call_without_tpu(self, features, labels, is_export_mode): return self._call_model_fn(features, labels, is_export_mode=is_export_mode) + def _add_embedding_features(self, features): + if self._ctx.embedding_config: + tpu_embedding_ = self._ctx.embedding_config.tpu_embedding + embedding_activations = tpu_embedding_.get_activations() + features.update(embedding_activations) + def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -1354,6 +1434,7 @@ class _ModelFnWrapper(object): del loss # unused; required in function signature. inputs = dequeue_fn() features, labels = inputs.features_and_labels() + self._add_embedding_features(features) estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels)) @@ -1370,11 +1451,19 @@ class _ModelFnWrapper(object): if tensor_tracer.TensorTracer.is_enabled(): tt = tensor_tracer.TensorTracer() loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, - self._ctx.num_replicas) + self._ctx.num_replicas, + fetches=[loss, train_op]) + + if self._ctx.embedding_config is None: + apply_sparse_grads = [] + else: + tpu_embedding_ = self._ctx.embedding_config.tpu_embedding + apply_sparse_grads = [tpu_embedding_.generate_send_gradients_op()] # We must run train_op to update the variables prior to running the # outfeed. - with ops.control_dependencies([train_op]+tracing_ops): + with ops.control_dependencies([train_op] + tracing_ops + + apply_sparse_grads): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): @@ -1420,6 +1509,7 @@ class _ModelFnWrapper(object): """Evaluation step function for use inside a while loop.""" inputs = dequeue_fn() features, labels = inputs.features_and_labels() + self._add_embedding_features(features) tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access @@ -1759,6 +1849,10 @@ class _OutfeedHostCall(object): dequeue_ops[j].append(item) # Deconstruct dequeue ops. + flat_dequeue_ops = [] + for l in dequeue_ops: + flat_dequeue_ops.extend(l) + dequeue_ops_by_name = {} pos = 0 for name in self._names: @@ -1766,6 +1860,14 @@ class _OutfeedHostCall(object): len(self._tensors[name])] pos += len(self._tensors[name]) + def _call_host_fn(fn, *args, **kw): + context = CatchInvalidHostcallFunctions() + context.Enter() + result = fn(*args, **kw) + context.Exit() + context.ExitResult(result) + return result + # It is assumed evaluation always happens on single host TPU system. So, # place all ops on tpu host if possible. # @@ -1799,7 +1901,7 @@ class _OutfeedHostCall(object): # The user-provided eval_metrics[1] is a dict. dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) try: - ret[name] = self._host_fns[name](**dequeue_ops) + ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops) except TypeError as e: logging.warning( 'Exception while calling %s: %s. It is likely the tensors ' @@ -1807,8 +1909,10 @@ class _OutfeedHostCall(object): 'function\'s arguments', name, e, name) raise else: - ret[name] = self._host_fns[name](*dequeue_ops) + ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops) + # force all dequeue operations to be run if not consumed by the host calls + ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops) return ret @@ -2100,7 +2204,11 @@ class TPUEstimator(estimator_lib.Estimator): batch_axis=None, eval_on_tpu=True, export_to_tpu=True, - warm_start_from=None): + export_to_cpu=True, + warm_start_from=None, + experimental_exported_model_uses_all_cores=False, + experimental_export_device_assignment=False, + experimental_embedding_config_spec=None): """Constructs an `TPUEstimator` instance. Args: @@ -2143,12 +2251,29 @@ class TPUEstimator(estimator_lib.Estimator): eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU besides the one on CPU. + serving on TPU. Note that unsupported export modes such as EVAL will be + ignored. For those modes, only a CPU model will be exported. + Currently, export_to_tpu only supports PREDICT. + export_to_cpu: If True, `export_savedmodel()` exports a metagraph for + serving on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string filepath is provided instead of a `WarmStartSettings`, then all variables are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + experimental_exported_model_uses_all_cores: Whether to round-robin among + all cores visible to the host which is serving the saved model, or to + use a single core. This is a temporary flag to enable using all TPU + cores for inference with TPUPartitionedCall(). Once outside compilation + is supported in TPUPartitionedCall(), this flag will be enabled by + default. + experimental_export_device_assignment: Whether to include the device + assignment in the exported model. Doing so is useful in case of model + parallel inference but will tie the exported model to the TPU topology + used to export the model. + experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance + to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE + DO NOT USE. Raises: ValueError: `params` has reserved keys already. @@ -2210,9 +2335,19 @@ class TPUEstimator(estimator_lib.Estimator): # pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu) + use_tpu, eval_on_tpu, experimental_embedding_config_spec) + self._export_to_cpu = export_to_cpu self._export_to_tpu = export_to_tpu + self._experimental_exported_model_uses_all_cores = ( + experimental_exported_model_uses_all_cores) + self._experimental_export_device_assignment = ( + experimental_export_device_assignment) + if (experimental_exported_model_uses_all_cores and + experimental_export_device_assignment): + raise ValueError('experimental_exported_model_uses_all_cores and ' + 'experimental_export_device_assignment is not supported ' + 'at the same time.') self._is_input_fn_invoked = None self._rendezvous = {} @@ -2226,35 +2361,43 @@ class TPUEstimator(estimator_lib.Estimator): export_tags=None, check_variables=True): if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; ' - 'got {}.'.format(mode)) - - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) + logging.warning('TPUEstimator only handles mode PREDICT for exporting ' + 'when `export_to_tpu` is `True`; Mode {} will be ignored ' + 'for TPU.'.format(mode)) + + if not self._export_to_cpu and not self._export_to_tpu: + raise ValueError('One of export_to_cpu and export_to_tpu must be true.') + + if self._export_to_cpu: + (super(TPUEstimator, self)._add_meta_graph_for_mode( + builder, + input_receiver_fn_map, + checkpoint_path, + save_variables, + mode=mode, + export_tags=export_tags, + check_variables=check_variables)) - if self._export_to_tpu: + if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: input_receiver_fn_map = { _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] } export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE + # See b/110052256 for why `check_variables` is `False`. + if not self._export_to_cpu: + check_variables = save_variables = True + else: + check_variables = save_variables = False (super(TPUEstimator, self)._add_meta_graph_for_mode( builder, input_receiver_fn_map, checkpoint_path, - save_variables=False, + save_variables=save_variables, mode=mode, export_tags=export_tags, - check_variables=False)) + check_variables=check_variables)) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: @@ -2269,6 +2412,88 @@ class TPUEstimator(estimator_lib.Estimator): raise ValueError('mode must be {}; ' 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + computation, capture = self._build_computation_for_inference( + features, labels, mode, config) + tensors = call_computation( + computation, + experimental_exported_model_uses_all_cores=self + ._experimental_exported_model_uses_all_cores) + estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( + capture.get()) + predictions_list = tensors[:len(predictions_dict)] + export_outputs_list_without_none = tensors[len(predictions_dict):] + + # Reinsert `None`s which we've taken out in + # `_build_computation_for_inference()`. + export_outputs_list = [] + while none_indices or export_outputs_list_without_none: + if none_indices and none_indices[0] == len(export_outputs_list): + export_outputs_list.append(None) + none_indices.pop(0) + else: + export_outputs_list.append(export_outputs_list_without_none.pop(0)) + + # Reconstruct `export_outputs` with updated tensors. + new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, + export_outputs_list) + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_export_outputs_dict)) + # Reconstruct `predictions` with updated tensors. + new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) + if (len(new_predictions) == 1 and + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): + new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] + + return estimator_spec._replace( + export_outputs=new_export_outputs, predictions=new_predictions) + + def _build_computation_for_inference(self, features, labels, mode, config): + capture = _CapturedObject() + + def computation(): + """Computation to be passed to `TPUPartitionedCall()`.""" + tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( + features, labels, mode, config) + + if self._experimental_export_device_assignment: + # Export the device assignment as part of the model. This is useful for + # model parallel usecases where the model relies on the mapping between + # logical and physical devices. + with self._ctx.with_mode(mode) as ctx: + device_assignment = ctx.device_assignment + else: + device_assignment = None + tensors_on_cpu = tpu.rewrite_for_inference( + tpu_computation, device_assignment=device_assignment) + (estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict) = ( + tpu_capture.get()) + predictions_list = tensors_on_cpu[:len(predictions_dict)] + export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] + + # Reconstruct tensors used in export_outputs, with TPU tensors replaced + # with their CPU counterpart returned from `rewrite_for_inference()`. + # `function.Defun()` does not like `None`s in return values, so we leave + # `None`s out but record their positions for later reconstruction. + export_outputs_list_without_none = [] + none_indices = [] + for i, t in enumerate(export_outputs_list): + if t is None: + none_indices.append(i) + else: + export_outputs_list_without_none.append( + export_outputs_tpu_on_cpu_list.pop(0)) + + capture.capture((estimator_spec, export_outputs_dict, predictions_dict, + none_indices)) + return predictions_list + export_outputs_list_without_none + + return computation, capture + + def _build_tpu_computation_for_inference(self, features, labels, mode, + config): capture = _CapturedObject() def computation(): @@ -2289,38 +2514,30 @@ class TPUEstimator(estimator_lib.Estimator): # We pick the TPU tensors out from `export_output` and later return them # from `computation` for rewriting. - tensors_dict = collections.OrderedDict( + export_outputs_dict = collections.OrderedDict( (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) - tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if t is not None] - - # We cannot return anything other than `tpu_tensors` here so we capture - # the rest for later use. - capture.capture((estimator_spec, tensors_dict, tensors)) - return tpu_tensors - - tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) - estimator_spec, tensors_dict, tensors = capture.get() - - # Reconstruct `tensors`, but with `tpu_tensors` replaced with - # `tpu_tensors_on_cpu`. - new_tensors = [] - for t in tensors: - if t is None: - new_tensors.append(None) + export_outputs_list = nest.flatten(export_outputs_dict) + export_outputs_tpu_list = [ + t for t in export_outputs_list if t is not None + ] + + if isinstance(estimator_spec.predictions, dict): + predictions_dict = collections.OrderedDict( + (k, v) for k, v in six.iteritems(estimator_spec.predictions)) else: - new_tensors.append(tpu_tensors_on_cpu.pop(0)) + predictions_dict = { + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions + } + predictions_list = nest.flatten(predictions_dict) - # Reconstruct `tensors_dict`. - new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) - # Reconstruct `export_outputs`. - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_tensors_dict)) + # We cannot return everything we want through the return values, so + # capture the rest here for later use. + capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict)) + return predictions_list + export_outputs_tpu_list - return estimator_spec._replace(export_outputs=new_export_outputs) + return computation, capture def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -2538,7 +2755,11 @@ class TPUEstimator(estimator_lib.Estimator): if self._log_every_n_steps is not None: examples_hook = ExamplesPerSecondHook( ctx.global_batch_size, - output_dir=self.model_dir, + # pylint:disable=g-long-ternary + output_dir=(self.model_dir + if not config or config.save_summary_steps + else None), + # pylint:enable=g-long-ternary every_n_steps=self._log_every_n_steps) if ctx.is_running_on_cpu(is_export_mode=is_export_mode): @@ -2555,6 +2776,10 @@ class TPUEstimator(estimator_lib.Estimator): assert callable(features), '`input_fn` is not callable.' input_fn = features + tpu_init_ops = [] + if ctx.embedding_config: + tpu_init_ops.extend(ctx.embedding_config.tpu_embedding.init_ops) + input_holders = _InputPipeline(input_fn, batch_axis, ctx) enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) @@ -2608,7 +2833,7 @@ class TPUEstimator(estimator_lib.Estimator): rendezvous=self._rendezvous[mode], master=self._config.master, session_config=self._session_config, - ), + tpu_init_ops=tpu_init_ops), InstallSignalHandlerHook() ]) if self._log_every_n_steps is not None: @@ -2645,6 +2870,10 @@ class TPUEstimator(estimator_lib.Estimator): with ops.control_dependencies([loss]): update_ops = _sync_variables_ops(ctx) + if ctx.embedding_config: + update_ops.extend( + ctx.embedding_config.tpu_embedding.retrieve_parameters_ops) + # Validate the TPU training graph to catch basic errors _validate_tpu_training_graph() @@ -2714,7 +2943,8 @@ class TPUEstimator(estimator_lib.Estimator): rendezvous=self._rendezvous[mode], master=self._config.evaluation_master, session_config=self._session_config, - )] + input_hooks + tpu_init_ops=tpu_init_ops) + ] + input_hooks if eval_hooks: hooks.extend(eval_hooks) diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc index 76cb5531cd0bc3a375d1434c31fa14a9d7f42476..d98e0b7a5ed52c00a8cf2b1a1bbc53f1b1cd28c7 100644 --- a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc @@ -134,12 +134,16 @@ Status GetGradientAccumulationSupport(OptimizationAlgorithm alg, } } namespace { -// Make a normal state variable specification. +// Make a normal state variable specification. Please refer to +// //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto +// (StateVariableSpecification message) for instructions on how to set the +// padding_initial_value field. StateVariableSpecification MakeStandardStateVariableSpecification( - const string& name) { + const string& name, double padding_initial_value) { StateVariableSpecification result; result.set_name(name); - result.mutable_user_defined(); + result.mutable_user_defined()->set_padding_initial_value( + padding_initial_value); return result; } } // namespace @@ -149,14 +153,14 @@ Status GetOptimizationAlgorithmStateVariables( std::vector* state_variables) { // The first parameter set is always the weights themselves. state_variables->push_back( - MakeStandardStateVariableSpecification("parameters")); + MakeStandardStateVariableSpecification("parameters", 0.0)); // The order of the returned parameters needs to match the offsets used by // the algorithm implementations in test_util.cc and // address_handler_program_creator.cc. switch (alg) { case OptimizationAlgorithm::kAdagrad: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); break; } case OptimizationAlgorithm::kStochasticGradientDescent: { @@ -165,53 +169,58 @@ Status GetOptimizationAlgorithmStateVariables( } case OptimizationAlgorithm::kFtrl: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); state_variables->push_back( - MakeStandardStateVariableSpecification("linears")); + MakeStandardStateVariableSpecification("linears", 0.0)); break; } case OptimizationAlgorithm::kAdam: { state_variables->push_back( - MakeStandardStateVariableSpecification("momenta")); + MakeStandardStateVariableSpecification("momenta", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("velocities")); + MakeStandardStateVariableSpecification("velocities", 0.0)); break; } case OptimizationAlgorithm::kMomentum: { state_variables->push_back( - MakeStandardStateVariableSpecification("momenta")); + MakeStandardStateVariableSpecification("momenta", 0.0)); break; } case OptimizationAlgorithm::kRmsProp: { - state_variables->push_back(MakeStandardStateVariableSpecification("ms")); - state_variables->push_back(MakeStandardStateVariableSpecification("mom")); + state_variables->push_back( + MakeStandardStateVariableSpecification("ms", 1.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mom", 0.0)); break; } case OptimizationAlgorithm::kCenteredRmsProp: { - state_variables->push_back(MakeStandardStateVariableSpecification("ms")); - state_variables->push_back(MakeStandardStateVariableSpecification("mom")); - state_variables->push_back(MakeStandardStateVariableSpecification("mg")); + state_variables->push_back( + MakeStandardStateVariableSpecification("ms", 1.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mom", 0.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mg", 0.0)); break; } case OptimizationAlgorithm::kMdlAdagradLight: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); state_variables->push_back( - MakeStandardStateVariableSpecification("weights")); + MakeStandardStateVariableSpecification("weights", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("benefits")); + MakeStandardStateVariableSpecification("benefits", 0.0)); break; } case OptimizationAlgorithm::kAdadelta: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("updates")); + MakeStandardStateVariableSpecification("updates", 0.0)); break; } case OptimizationAlgorithm::kProximalAdagrad: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); break; } case OptimizationAlgorithm::PARAMETERS_NOT_SET: { diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index bcc177601b95172b05d327247bd370c2f8b65d59..27f0d9b2e38c433d4fb4573285ecb8c9946112e8 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -499,6 +499,7 @@ class HParams(object): value: New value of the hyperparameter. Raises: + KeyError: If the hyperparameter doesn't exist. ValueError: If there is a type mismatch. """ param_type, is_list = self._hparam_types[name] @@ -517,6 +518,8 @@ class HParams(object): def del_hparam(self, name): """Removes the hyperparameter with key 'name'. + Does nothing if it isn't present. + Args: name: Name of the hyperparameter. """ @@ -525,19 +528,20 @@ class HParams(object): del self._hparam_types[name] def parse(self, values): - """Override hyperparameter values, parsing new values from a string. + """Override existing hyperparameter values, parsing new values from a string. See parse_values for more detail on the allowed format for values. Args: - values: String. Comma separated list of `name=value` pairs where - 'value' must follow the syntax described above. + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. Returns: The `HParams` instance. Raises: - ValueError: If `values` cannot be parsed. + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. """ type_map = dict() for name, t in self._hparam_types.items(): @@ -548,7 +552,7 @@ class HParams(object): return self.override_from_dict(values_map) def override_from_dict(self, values_dict): - """Override hyperparameter values, parsing new values from a dictionary. + """Override existing hyperparameter values, parsing new values from a dictionary. Args: values_dict: Dictionary of name:value pairs. @@ -557,6 +561,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. ValueError: If `values_dict` cannot be parsed. """ for name, value in values_dict.items(): @@ -596,7 +601,7 @@ class HParams(object): sort_keys=sort_keys) def parse_json(self, values_json): - """Override hyperparameter values, parsing new values from a json object. + """Override existing hyperparameter values, parsing new values from a json object. Args: values_json: String containing a json object of name:value pairs. @@ -605,6 +610,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. ValueError: If `values_json` cannot be parsed. """ values_map = json.loads(values_json) diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index c272a2ac144068cfb7355c2647eebf5bd0ce9d50..fc6e38ab4a5243cb7502f4ca42db03cbfd342a40 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -419,7 +419,7 @@ def create_train_op(total_loss, update_ops = set(update_ops) if not global_update_ops.issubset(update_ops): logging.warning('update_ops in create_train_op does not contain all the ' - ' update_ops in GraphKeys.UPDATE_OPS') + 'update_ops in GraphKeys.UPDATE_OPS') # Make sure update_ops are computed before total_loss. if update_ops: diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index d9ccda8e89a4c9a1b3f3d24915b9ad3fb4d9be5f..07dbd5ca8d65ec8232d33c016a7369c68a4c9e1f 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -16,9 +16,12 @@ cc_library( srcs = ["convert_graphdef_memmapped_format_lib.cc"], hdrs = ["convert_graphdef_memmapped_format_lib.h"], deps = [ + "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8bf1480d33b2d2117fb5c7ddf046262cfeb8a8ab..a932974270f5dc00ba61b1f6e57ee7b00778039c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -70,6 +70,9 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +# Export the BUILD file so automated tooling can check licenses +exports_files(["BUILD"]) + load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", @@ -178,7 +181,6 @@ COMMON_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", - "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", @@ -203,6 +205,7 @@ COMMON_PROTO_SRCS = [ "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + "protobuf/verifier_config.proto", "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", @@ -445,7 +448,8 @@ cc_library( ) cc_library( - name = "logger_interface", + name = "logger", + srcs = ["platform/logger.cc"], hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], @@ -455,23 +459,6 @@ cc_library( ], ) -cc_library( - name = "default_logger", - srcs = ["platform/default/logger.cc"], - hdrs = ["platform/logger.h"], - deps = [ - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:logger_interface", - ], -) - -cc_library( - name = "logger", - hdrs = ["platform/logger.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core/platform/default/build_config:logger"], -) - filegroup( name = "platform_env_hdrs", srcs = [ @@ -520,6 +507,7 @@ cc_library( ":platform_port", ":platform_protobuf", "//tensorflow/core/platform/default/build_config:env", + "//tensorflow/core/platform/default/build_config:port", ], ) @@ -1033,6 +1021,7 @@ cc_library( ":lib", ":lib_internal", ":protos_all_cc", + "//tensorflow/core/util/proto:proto_utils", ], ) @@ -1090,6 +1079,7 @@ tf_gen_op_libs( "tensor_forest_ops", "candidate_sampling_ops", "checkpoint_ops", + "clustering_ops", "collective_ops", "control_flow_ops", "ctc_ops", @@ -1115,6 +1105,7 @@ tf_gen_op_libs( "parsing_ops", "random_grad", "random_ops", + "stateful_random_ops", "remote_fused_graph_ops", "rpc_ops", "scoped_allocator_ops", @@ -1244,6 +1235,7 @@ cc_library( ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", + ":clustering_ops_op_lib", ":collective_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", @@ -1269,6 +1261,7 @@ cc_library( ":parsing_ops_op_lib", ":ragged_ops", ":random_ops_op_lib", + ":stateful_random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", ":rpc_ops_op_lib", @@ -1387,7 +1380,7 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( - name = "all_kernels_statically_linked", + name = "all_kernels_impl", visibility = ["//visibility:private"], deps = [ "//tensorflow/core/kernels:array", @@ -1398,12 +1391,12 @@ cc_library( "//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", + "//tensorflow/core/kernels:clustering_ops", "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", "//tensorflow/core/kernels:fake_quant_ops", @@ -1414,18 +1407,20 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", - "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", + "//tensorflow/core/kernels:mutex_ops", "//tensorflow/core/kernels:nn", "//tensorflow/core/kernels:parameterized_truncated_normal_op", "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/kernels:pooling_ops", "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:stateful_random_ops", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", @@ -1477,8 +1472,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_dynamic_kernels( [], - otherwise = [":all_kernels_statically_linked"], - ), + otherwise = [":all_kernels_impl"], + ) + [ + # TODO(gunan): Work on the API between these and rest of TF and make + # these also dynamically loading. + "//tensorflow/core/kernels:dataset_ops", # Depends on grappler + "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + ], ) tf_cuda_library( @@ -1763,6 +1763,7 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], @@ -1963,6 +1964,14 @@ cc_library( ], ) +cc_library( + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform/default/build_config:rocm", + ], +) + # ----------------------------------------------------------------------------- # Clif-related proto libraries. @@ -2022,6 +2031,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/step_stats_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/step_stats.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/types_pyclif", proto_lib = ":protos_all_cc", @@ -2199,6 +2215,7 @@ cc_library( ], }), deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "//third_party/eigen3", "@com_google_absl//absl/base:core_headers", @@ -2213,7 +2230,6 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", - ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2353,7 +2369,12 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [], + srcs = [ + ] + if_ios([ + "platform/default/logging.cc", + "platform/env_time.cc", + "platform/posix/env_time.cc", + ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2362,7 +2383,7 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]), + ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), copts = tf_copts(), linkopts = ["-ldl"], deps = [ @@ -2772,6 +2793,7 @@ cc_library( # in this library. GRAPH_HDRS = [ "graph/algorithm.h", + "graph/collective_order.h", "graph/colors.h", "graph/control_flow.h", "graph/costmodel.h", @@ -2798,6 +2820,7 @@ tf_cuda_library( name = "graph", srcs = [ "graph/algorithm.cc", + "graph/collective_order.cc", "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", @@ -2815,6 +2838,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -2829,12 +2855,16 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ "framework/versions.h", "common_runtime/process_function_library_runtime.h", "common_runtime/function.h", + "common_runtime/scoped_allocator.h", + "common_runtime/scoped_allocator_mgr.h", ] tf_cuda_library( name = "core_cpu_base", srcs = [ "common_runtime/eval_const_tensor.cc", + "common_runtime/scoped_allocator.cc", + "common_runtime/scoped_allocator_mgr.cc", "common_runtime/shape_refiner.cc", "common_runtime/shape_refiner.h", "framework/versions.h", @@ -2894,6 +2924,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", + "common_runtime/partitioning_utils.h", "common_runtime/placer.h", "common_runtime/process_util.h", "common_runtime/profile_handler.h", @@ -2901,8 +2932,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", - "common_runtime/scoped_allocator.h", - "common_runtime/scoped_allocator_mgr.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2950,6 +2979,7 @@ tf_cuda_library( "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/partitioning_utils.cc", "common_runtime/placer.cc", "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", @@ -2959,8 +2989,6 @@ tf_cuda_library( "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", "common_runtime/ring_reducer.cc", - "common_runtime/scoped_allocator.cc", - "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", @@ -2988,8 +3016,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:functions", ] + mkl_deps(), alwayslink = 1, ) @@ -3017,6 +3046,7 @@ tf_cuda_library( ":framework", ":graph", ":lib", + ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", @@ -3504,6 +3534,29 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_fake_python_env_test", + size = "small", + srcs = ["platform/fake_python_env_test.cc"], + args = [ + "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", + ], + tags = [ + "local", + "no_windows", + "nogpu", + "nomac", + "notap", + ], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "platform_abi_test", size = "small", @@ -3677,7 +3730,6 @@ tf_cc_tests( srcs = [ "common_runtime/buf_rendezvous_test.cc", "common_runtime/collective_executor_mgr_test.cc", - "common_runtime/collective_param_resolver_local_test.cc", "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", @@ -3793,6 +3845,7 @@ tf_cc_tests( name = "higher_level_tests_needing_kernels", size = "small", srcs = [ + "common_runtime/collective_param_resolver_local_test.cc", "graph/graph_constructor_test.cc", ], linkopts = select({ @@ -3842,6 +3895,27 @@ tf_cc_test( ], ) +tf_cc_tests( + name = "collective_order_test", + size = "small", + srcs = [ + "graph/collective_order_test.cc", + ], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_tests_gpu( name = "ring_reducer_test", size = "medium", @@ -4191,7 +4265,7 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_process_function_library_runtime_test", size = "small", srcs = ["common_runtime/process_function_library_runtime_test.cc"], @@ -4200,6 +4274,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":framework_internal", ":lib", ":test", ":test_main", @@ -4208,6 +4283,7 @@ tf_cc_test( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:resource_variable_ops", ], ) @@ -4249,6 +4325,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_partitioning_utils_test", + size = "small", + srcs = ["common_runtime/partitioning_utils_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":lib", + ":ops", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt index 070d6adb978e4a62e7209f299dba08515aa21e83..d0794de4ba4a174838547865e4f1692cff503052 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt @@ -33,6 +33,15 @@ END name: "padding" description: <