diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c8212b7445c32f241d887306d3c19ad..8669c25c452b53da48239bc20c9a2d3528e75422 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you followed this list. + +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements diff --git a/README.md b/README.md index e1a50c87e26d493ba3ac760f357905d89aa40dab..6fb4486d0de9ff476b5cf1dbd63d66879637df84 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ ----------------- -| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------|---------------| -| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +| **`Documentation`** | +|-----------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. -**Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) -* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) -* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) -([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) - #### *Try your first TensorFlow program* ```shell $ python @@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + +## Continuous build status + +### Official Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | TBA | TBA | +| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | + + +### Community Supported Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | + + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee202f6e7d83f0fedc2777d9c88ba9b5..01886b613e5d93793953124331b57f075fe7a373 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -173,7 +173,7 @@ the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. This will help us evaluate your report diff --git a/configure.py b/configure.py index fe15bfc1a43bac5d9c249bf5b61854ff0e07aec7..b6c32543cf707983d48e390cc89abf13dafd55d3 100644 --- a/configure.py +++ b/configure.py @@ -498,10 +498,6 @@ def set_cc_opt_flags(environ_cp): if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. - write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -845,8 +841,8 @@ def reformat_version_sequence(version_str, sequence_count): def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use, ' - 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION + 'Please specify the CUDA SDK version you want to use. ' + '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 82dbd3cdbc6e8fb0c6fbcddb33b6a95c87a83225..95b04f9058afdfaadbc24f0238860279fcd3e800 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, } return ret; } + +void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TF_Tensor* tensor, TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + if (VLOG_IS_ON(1)) { + VLOG(1) << "Enqueuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + tensorflow::Tensor internal_tensor; + if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) { + VLOG(1) << "Enqueu'ing tensor content: " + << internal_tensor.DebugString(); + } + } + } + + TF_Operation* enqueue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str()); + if (enqueue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the enqueue node in the TF graph."); + return; + } + + TF_Operation* placeholder_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str()); + if (placeholder_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the placeholder node as input to enqueue in the TF " + "graph."); + return; + } + + VLOG(1) << "Running the enqueue op"; + TF_Output input{placeholder_op, 0}; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0, + /*targets*/ &enqueue_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); + VLOG(1) << "Enqueuing is done."; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index e6757c065fc540fa789cdbb694e66ca0b00c4832..20bdace40f1272ded06e710034053a7610326e7f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -87,8 +87,11 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( unsigned char is_mnist, TF_Status* status); // On success, dequeues a tensor from a TF-managed FifoQueue given by -// `tensor_id`, associated with `session`. Caller must call TF_DeleteTensor() -// over the returned tensor. If the queue is empty, this call is blocked. +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. // // Tensors are enqueued via the corresponding TF enqueue op. // TODO(hongm): Add support for `timeout_ms`. @@ -96,6 +99,22 @@ TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, TF_Status* status); +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index cd19cf8d624d9b914b61132f93d918b046cdbd30..c16aba666ee6974fed5351c2d9ac291dcbcdecab 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 14321191625e448637aa44a7f6a17820159b97c2..9ce781fab0be709fb0f115a2206eea4c2826bf36 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -24,10 +24,10 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -49,6 +49,17 @@ tf_cuda_library( "//conditions:default": [], }) + [ "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", ], ) @@ -59,7 +70,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -69,11 +79,23 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -92,47 +114,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - ], -) - -tf_cuda_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3bf071f3abaac7dfd4113964fd49cd9322913bd5..216210c88c1593ebc68f604547ab06b543a7b2af 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -32,15 +31,22 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -71,6 +77,121 @@ string DeviceName(const tensorflow::Device* d) { std::atomic_int_fast64_t func_id_generator(0); #endif // TENSORFLOW_EAGER_USE_XLA +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + std::vector remote_devices; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + for (const string& remote_worker : remote_workers) { + tensorflow::Notification n; + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, remote_worker, + [&status, &n, &remote_devices]( + const tensorflow::Status& s, + std::vector* devices) { + status = s; + if (s.ok()) { + for (tensorflow::Device* d : *devices) { + remote_devices.push_back(d); + } + } + n.Notify(); + }); + n.WaitForNotification(); + } + std::unique_ptr remote_device_mgr( + new tensorflow::DeviceMgr(remote_devices)); + + TF_RETURN_IF_ERROR(status); + + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + const std::vector& remote_workers, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::gtl::FlatMap* remote_contexts) { + for (int i = 0; i < remote_workers.size(); i++) { + const string& remote_worker = remote_workers[i]; + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse response; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + return tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + } + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.set_async(async); + auto* eager_client = remote_eager_workers->GetClient(remote_worker); + if (eager_client == nullptr) { + return tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + tensorflow::Notification n; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + eager_client->CreateContextAsync( + &request, &response, [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + remote_contexts->emplace(remote_worker, response.context_id()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, + TFE_Context** ctx) { + string worker_name = tensorflow::strings::StrCat( + "/job:", opts->server_def.job_name(), + "/replica:0/task:", opts->server_def.task_index()); + std::unique_ptr server; + TF_RETURN_IF_ERROR( + tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); + + TF_RETURN_IF_ERROR(server->Start()); + + std::vector remote_workers; + server->master_env()->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase( + std::remove(remote_workers.begin(), remote_workers.end(), worker_name), + remote_workers.end()); + + std::unique_ptr remote_device_mgr; + TF_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + + std::shared_ptr channel_cache = + server->channel_cache(); + std::unique_ptr remote_eager_workers( + tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + + // Initialize remote eager workers. + tensorflow::gtl::FlatMap remote_contexts; + TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); + + tensorflow::RemoteRendezvous* r = + server->worker_env()->rendezvous_mgr->Find(0); + + auto* device_mgr = server->worker_env()->device_mgr; + *ctx = new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, r, std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts); + + return tensorflow::Status::OK(); +} } // namespace extern "C" { @@ -91,6 +212,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status) { + if (!options->server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + } +} + TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -100,17 +230,23 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { + if (!opts->server_def.job_name().empty()) { + TFE_Context* ctx = nullptr; + status->status = NewRemoteAwareTFE_Context(opts, &ctx); + return ctx; + } + std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); - if (!status->status.ok()) { - return nullptr; - } + if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, opts->policy, opts->async, std::move(device_mgr), r); } @@ -119,7 +255,10 @@ void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context.remote_device_mgr()) { + ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + } return list; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index c06ce84a8c578aa60dd626c24bd58098b78ae750..574a097e0d6f5d6e7acd77cae246678b6675129b 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 49e1aab1cef9577256d9b081858cf094c788caf8..2b8384d72038c3b4a050c70b3e7c5e0ca0bd94f3 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -37,6 +37,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -51,6 +59,7 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; + tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -64,6 +73,23 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} + explicit TFE_Context( + const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + tensorflow::DeviceMgr* local_device_mgr, + tensorflow::Rendezvous* rendezvous, + std::unique_ptr server, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_mgr, + const tensorflow::gtl::FlatMap& + remote_contexts) + : context(opts, + static_cast( + default_policy), + async, local_device_mgr, rendezvous, std::move(server), + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts) {} + tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 701175e4943d1d23532fe595319f67711316ed4d..49646bb73599d96fce2df90f918e692df7972aeb 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -23,7 +24,9 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::string; @@ -220,6 +223,103 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } +tensorflow::ServerDef GetServerDef(int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name("localhost"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name("localhost"); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +void TestRemoteExecute(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } +TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 8026076b9ef3bf07f6f84caf80329a1301c6a27f..1833b25fea0047c9652318e49599ba623daaec26 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -104,6 +104,10 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; @@ -130,13 +134,15 @@ class GradientTape { } } - bool ShouldRecord(gtl::ArraySlice tensor_ids); + bool ShouldRecord(gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes); void Watch(int64 tensor_id); void RecordOperation(const string& op_type, gtl::ArraySlice output_tensors, gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, const std::function& backward_function_deleter); @@ -170,12 +176,32 @@ class GradientTape { // Template instantiations here +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + default: + return false; + } +} + template bool GradientTape::ShouldRecord( - gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; + gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes) { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } } } return false; @@ -189,9 +215,11 @@ void GradientTape::Watch(int64 tensor_id) { template void GradientTape::RecordOperation( const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { backward_function_deleter(); return; } @@ -332,8 +360,7 @@ BackpropInitialState PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -498,10 +525,15 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector in_gradients; diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1b4c7c2688083e74433da3dce2849b8c37443684..fd7b6fe6625f27bda92e2f56f60908658cdecd7e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -31,7 +31,6 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; -using ops::Greater; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -46,7 +45,6 @@ using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::Where3; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0cb3132e94e381f672d69aefe4a199d2b590830c..c73482d5f4d13ade0dc0412941251d1651371b6e 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -255,6 +255,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb017fe4628140d75a63bc7f0f09deb7f..b4d457a9d14eb79232cda9412fa0050f6a9968cc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 4ddddcb5863c9ffb1e5367db750b0d2ffd29cd5e..23e9dc40d23899b9cef168c9128b6d8ed1be3ee9 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include #include #include "tensorflow/core/framework/attr_value.pb.h" @@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap( } } +// Strips off the tensor part of the tensor_name to get the node_name. +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + return tensor_name_parts[0]; +} + // Gets the set of node names needed by `outputs` and the corresponding set of // variable nodes to convert. void GetReachableNodesAndVariables( @@ -83,10 +93,8 @@ void GetReachableNodesAndVariables( new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); std::queue nodes_to_visit; - for (const string& tensor_name : outputs) { - // We need to strip off the tensor part to get the node name. - std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); - nodes_to_visit.push(tensor_name_parts[0]); + for (const string& output_tensor_name : outputs) { + nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); } // We do a traversal backwards from the outputs specified in the MetaGraphDef. while (!nodes_to_visit.empty()) { @@ -100,8 +108,8 @@ void GetReachableNodesAndVariables( if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } - for (const string& input : node->input()) { - nodes_to_visit.push(input); + for (const string& input_tensor_name : node->input()) { + nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); } } } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index cd35fd3b95deec669218cfa4f25fea2c3ac9e56e..979b23c3fc5f66ec574736cb4d39cec0ffd8e6b6 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } +TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { + // Tensors from operations with multiple outputs get tensor suffixes when used + // in input fields of following nodes, i.e. split:0, split:1. + // Test that we traverse those correctly. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2}); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output; + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), split[1], b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { TestFreezeGraphWithoutDependentVariables(false); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 19e6bf68e77725bb3cae4e1d338c52dff472cb18..2119c8ec47f941a76e81346ae5d20da78eae11a3 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -214,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:support", "@llvm//:target", ], diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 6e050cf56494e6d26e3647e3261a657eeaad64fa..6641d45e83020f4144616a6a2837c844330298f5 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -56,9 +56,9 @@ namespace bar { // // Memory stats: // arg bytes total: 104 -// arg bytes aligned: 128 +// arg bytes aligned: 192 // temp bytes total: 126 -// temp bytes aligned: 224 +// temp bytes aligned: 320 class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 63d22de1ca4aa0872b6fad3e0ac0182306d7cb8c..4e27aafec7747655d8e4ea3ddd1788d495ca0710 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -82,7 +82,8 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, llvm::legacy::PassManager codegen_passes; if (target_machine->addPassesToEmitFile( - codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + codegen_passes, ostream, nullptr, + llvm::TargetMachine::CGFT_ObjectFile)) { return xla::InternalError( "Could not create pass pipeline to generate object file"); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index ebfe4806c203e901358d5c5096c10c03d4c738c3..4e194a6aba9a9efcad27c47c42e148d8e537ae68 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -71,7 +71,7 @@ struct ProtobufToEmbed { const ::tensorflow::protobuf::MessageLite* message; }; -// Embeds a a sequence of protocol buffers into an object file. +// Embeds a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h index d085864f0012e4de55685bb46961417bb3070e6f..d1a669ceb17b9fd71d26e978035283f8824b0376 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/aot/runtime.h @@ -25,8 +25,8 @@ namespace tensorflow { namespace tfcompile { namespace runtime { -// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 32; +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +static constexpr size_t kAlign = 64; // aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 6d603a02eb4ceade6832ba67b2981814ee25327a..06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -24,7 +24,7 @@ namespace runtime { namespace { TEST(Runtime, AlignmentValue) { - // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 @@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) { EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) { EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); - EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[2], add_ptr(base, 64)); EXPECT_EQ(bufD[3], nullptr); - EXPECT_EQ(bufD[4], add_ptr(base, 64)); - EXPECT_EQ(bufD[5], add_ptr(base, 128)); - EXPECT_EQ(bufD[6], add_ptr(base, 160)); + EXPECT_EQ(bufD[4], add_ptr(base, 128)); + EXPECT_EQ(bufD[5], add_ptr(base, 192)); + EXPECT_EQ(bufD[6], add_ptr(base, 256)); for (int i = 0; i < 7; ++i) { const intptr_t size = sizesD[i]; if (size != -1) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 222e26810ac1157152ea81a56749b6652aa1f137..fd2cf2b67d4618dd626b8eef78eed044d7fde0a4 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -15,6 +15,7 @@ test_suite( ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", ":test_graph_tfassert_eq_test", + ":test_graph_tfcond_test", ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", @@ -55,6 +56,7 @@ genrule( "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", "test_graph_tfassert_eq.pb", + "test_graph_tfcond.pb", "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", @@ -118,6 +120,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfcond", + testonly = 1, + config = "test_graph_tfcond.config.pbtxt", + cpp_class = "CondComp", + graph = "test_graph_tfcond.pb", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tffunction", testonly = 1, @@ -194,6 +207,7 @@ tf_cc_test( ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", ":test_graph_tfassert_eq", + ":test_graph_tfcond", ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 67767f55dae9b15aafbd8b129328bde2c59a9ef3..9ec7df163b1425f917e9ec51559efad3e6f05e75 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir): f.write(saver.as_saver_def().SerializeToString()) +def tfassert_eq(_): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + control_flow_ops.Assert( + math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') + math_ops.add(x, math_ops.negative(y), name='x_y_diff') + + +def tfcond(_): + p = array_ops.placeholder(dtypes.bool, name='p_hold') + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + z = control_flow_ops.cond(p, lambda: x, lambda: y) + array_ops.identity(z, name='result') + + def tfgather(_): params = array_ops.placeholder(dtypes.float32, name='params') indices = array_ops.placeholder(dtypes.int32, name='indices') @@ -126,14 +142,6 @@ def tfsplits(_): array_ops.identity(y, name='result') -def tfassert_eq(_): - x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = array_ops.placeholder(dtypes.int32, name='y_hold') - control_flow_ops.Assert( - math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') - math_ops.add(x, math_ops.negative(y), name='x_y_diff') - - def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -148,12 +156,13 @@ def main(_): write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) + write_graph(tfassert_eq, FLAGS.out_dir) + write_graph(tfcond, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) - write_graph(tffunction, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) - write_graph(tfassert_eq, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..94a01ad4abfaab5e4b087b7cc219e86c1d0179b8 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt @@ -0,0 +1,20 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "p_hold" } + shape {} +} +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "result" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 27ba42b31fc2504a570584a9881c032582731baf..fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h" #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" @@ -39,7 +40,7 @@ namespace tfcompile { namespace { using ::testing::HasSubstr; -using ::testing::UnorderedElementsAre; +using ::testing::IsSupersetOf; TEST(TFCompileTest, Add) { AddComp add; @@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); } +TEST(TFCompileTest, Cond) { + CondComp cond; + EXPECT_EQ(cond.arg0_data(), cond.args()[0]); + EXPECT_EQ(cond.arg1_data(), cond.args()[1]); + EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + cond.arg1() = 10; + cond.arg2() = 20; + { + cond.arg0() = true; + const int32 expected_result = cond.arg1(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } + { + cond.arg0() = false; + const int32 expected_result = cond.arg2(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } +} + TEST(TFCompileTest, Gather) { GatherComp gather; EXPECT_EQ(gather.arg0_data(), gather.args()[0]); @@ -525,25 +551,20 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto add_profile_line = HasSubstr( - "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto tuple_profile_line = HasSubstr( "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.2, f32[2,2]{1,0} %add.0.5)"); + "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); - hlo_profile_lines.erase(hlo_profile_lines.begin() + 7, - hlo_profile_lines.end()); - - EXPECT_THAT( - hlo_profile_lines, - UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line, - add_profile_line, tuple_profile_line, - arg0_profile_line, arg1_profile_line)); + EXPECT_THAT(hlo_profile_lines, + IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, + add_profile_line, tuple_profile_line})); } } // namespace diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 07136d6a746604f148b62e48480a4fa0d253927d..980e0eec9e23b15a97b826067bac08053a437712 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -124,7 +124,6 @@ cc_library( srcs = ["xla_tensor.cc"], hdrs = ["xla_tensor.h"], deps = [ - ":common", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -176,11 +175,11 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:variable_ops", - "@com_google_absl//absl/memory", ], ) @@ -217,6 +216,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", @@ -261,6 +261,7 @@ cc_library( name = "create_xla_launch_op", srcs = [ "create_xla_launch_op.cc", + "create_xla_launch_op.h", ], deps = [ ":common", @@ -270,6 +271,27 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "create_xla_launch_op_test", + srcs = [ + "create_xla_launch_op.h", + "create_xla_launch_op_test.cc", + ], + deps = [ + ":create_xla_launch_op", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb0007527557f79b70ad2b9c9576af2ab10ea..b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 18d901323f108505979be484c2bfad5998ab0748..731b8ebfdc6262500940274c94a03ae7c0376096 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/create_xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" @@ -21,82 +22,194 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { -// Givens a NodeDef 'ndef' and the function library runtime 'flr', if -// 'ndef' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. // -// This routine is here so that FunctionLibraryRuntime can jit a -// specific function call as requested. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) { - bool xla_compile = false; - if (!flr->GetFunctionLibraryDefinition() - ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) - .ok() || - !xla_compile) { - // Not marked as _XlaCompile=true. - return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; } - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, ndef)) { - // ndef is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + + private: + int current_index_; + const std::vector* values_; +}; + +Status CompilationRequested(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { + bool xla_compile = false; + // Check if op is marked _XlaCompile=true. + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return Status(error::INVALID_ARGUMENT, ""); } + return Status::OK(); +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If ndef is not instantiable, e.g., the function does not exist, + // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); // Can't be nullptr since we just instantiated it. - std::vector const_args(fbody->arg_types.size()); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - // There is a const arg. Bail out. - return errors::InvalidArgument("Const arg: ", i, " in ", - DebugString(fbody->fdef)); + constant_arg_indices->push_back(i); + } + } + + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. + resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); } } - NodeDef launch_def; - launch_def.set_name(ndef.name()); - launch_def.set_op("_XlaLaunch"); - launch_def.set_device(flr->device()->name()); - AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); - AddNodeAttr("Nresources", 0, &launch_def); - AddNodeAttr("Targs", fbody->arg_types, &launch_def); - AddNodeAttr("Tresults", fbody->ret_types, &launch_def); - NameAttrList func; - func.set_name(ndef.op()); - *(func.mutable_attr()) = ndef.attr(); - AddNodeAttr("function", func, &launch_def); - - // TODO(b/32387911): Handles the host memory types across function - // calls properly. For now, we assume all inputs and outputs are on - // the device memory. + return Status::OK(); +} + +} // namespace + +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) { + TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); + + VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, node_def)) { + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", + node_def.ShortDebugString()); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + // Create the kernel. + NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &launch_def, + dev->GetAllocator(AllocatorAttributes()), &node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - kernel->reset(new XlaLocalLaunchOp(&construction)); + + *kernel = MakeUnique(&construction, constant_arg_indices, + resource_arg_indices, function); return s; } +namespace { + bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h new file mode 100644 index 0000000000000000000000000000000000000000..98a22e351532c197c69c5ea908305d885fd2c9d0 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +// Given a NodeDef 'node_def' and the function library runtime 'flr', if +// 'node_def' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/create_xla_launch_op.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +// Create a FunctionDef that takes one resource and one regular param +FunctionDef XTimesY() { + return FunctionDefHelper::Define( + // Name + "XTimesY", + // Args + {"x: float", "y: resource"}, + // Return values + {"z: float"}, + // Attr def + {}, + // Nodes + { + {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, + {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, + }); +} + +class CreateXlaLaunchOpTest : public ::testing::Test { + protected: + void Init(const std::vector& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + + FunctionDefLibrary proto; + for (const auto& fdef : flib) { + *(proto.add_function()) = fdef; + } + lib_def_ = + MakeUnique(OpRegistry::Global(), proto); + OptimizerOptions opts; + device_mgr_ = MakeUnique(devices_); + pflr_ = MakeUnique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + } + + FunctionLibraryRuntime* flr_; + std::vector devices_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; + + std::unique_ptr kernel_; +}; + +AttrValue BoolAttr(bool b) { + AttrValue v; + v.set_b(b); + return v; +} + +TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + Init({fdef}); + + Status status = CreateXlaLaunchOp( + flr_, ToNodeDef(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), &kernel_); + ASSERT_TRUE(status.ok()) << status.ToString(); + + EXPECT_EQ("XTimesY", kernel_->name()); + EXPECT_EQ("XTimesY", kernel_->type_string()); + + EXPECT_EQ(2, kernel_->num_inputs()); + EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); + EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); + EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); + + EXPECT_EQ(1, kernel_->num_outputs()); + EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { + FunctionDef fdef = XTimesY(); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a381197d2191e083727aa8d48ab8cd63..5fee36f022a7515504cb6faa5cca658481b784c5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index bc68afb322b5cfc814ce0537254ba14053ae4550..805bbc62c1e2e877de87ab8faf3d60b829743df8 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { return reachable; } +bool GraphCycles::CanContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b; + RemoveEdge(a, b); + bool reachable = IsReachableNonConst(a, b); + // Restore the graph to its original state. + InsertEdge(a, b); + // If reachable, then contracting edge will cause cycle. + return !reachable; +} + bool GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); @@ -388,4 +398,8 @@ std::unordered_set GraphCycles::Successors(int32 node) { return rep_->nodes_[node]->out; } +std::unordered_set GraphCycles::Predecessors(int32 node) { + return rep_->nodes_[node]->in; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index d11d6e27b1b7bb514127e16a9be21f044100d885..44448fa3d787d0785a797d40ed1b968438a903c9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -85,6 +85,9 @@ class GraphCycles { // and returns false. bool ContractEdge(int32 a, int32 b); + // Return true if can contract edge, otherwise return false. + bool CanContractEdge(int32 a, int32 b); + // Return whether dest_node is reachable from source_node // by following edges. bool IsReachable(int32 source_node, int32 dest_node) const; @@ -115,6 +118,7 @@ class GraphCycles { bool CheckInvariants() const; std::unordered_set Successors(int32 node); + std::unordered_set Predecessors(int32 node); // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index e47b782207e9122740fe9d5daf1fa0dbaeb47754..274f5938a1228baf68ad4d8e1a7b13f276321d27 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) { EXPECT_TRUE(g_.HasEdge(1, 4)); } +TEST_F(GraphCyclesTest, CanContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.CanContractEdge(1, 3)); + EXPECT_FALSE(g_.CanContractEdge(2, 4)); + EXPECT_TRUE(g_.CanContractEdge(1, 2)); + EXPECT_TRUE(g_.CanContractEdge(2, 3)); + EXPECT_TRUE(g_.CanContractEdge(3, 4)); +} + static void BM_StressTest(int iters, int num_nodes) { while (iters > 0) { tensorflow::GraphCycles g; diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 049d170fa48928474b894f2d0e1f2243c5f87275..27287e0f9637929b2e04c6a76de19c2785ec357e 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -39,15 +39,15 @@ limitations under the License. namespace tensorflow { -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx), device_type_(ctx->device_type()) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + device_type_(ctx->device_type()), + function_(function) { if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { @@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) } } -Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOp::Compute " +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -112,7 +112,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata; + const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { } std::map variables = - SnapshotResourceVariables(ctx, num_resource_args_); + SnapshotResourceVariables(ctx, resources_); xla::LocalClient* client = static_cast(cache->client()); @@ -153,25 +153,27 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - // TODO(b/77671268): We don't set variable_representation_shape_fn here. This - // is restricted to Variables, but we need something like this to apply to - // normal Tensors too. + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; std::map constant_args; - for (int i = 0; i < num_constant_args_; ++i) { + for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context( - num_resource_args_, client, xla_allocator, allocate_xla_tensors); + XlaComputationLaunchContext launch_context(client, xla_allocator, + allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -194,14 +196,69 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Done"; } +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") .Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 8f8e646f0ff6d94dfdf56721cacfce7fa658beb6..8dfc4b382d51151b6383fe7dd75429f3124d39be 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,6 +26,41 @@ limitations under the License. namespace tensorflow { +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache); + + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + DeviceType device_type_; + NameAttrList function_; + se::Platform::Id platform_id_; +}; + // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. @@ -35,26 +70,12 @@ namespace tensorflow { // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public OpKernel { +class XlaLocalLaunchOp : public XlaLocalLaunchBase { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; - void Compute(OpKernelContext* ctx) override; - private: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** compiler); - - DeviceType device_type_; - NameAttrList function_; - int num_constant_args_; - // Number of resource variable arguments. - int num_resource_args_; - - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab790e6cda5e85688bdacf48a35adc4..f2473d98ffd5dae55983e601b8d2d65af6a6d54c 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6430975335f5eef5b53c80213e6090ffd6166a91..7ed609c43748062656b631243c01d790519c54fd 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature( namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 60458f6f3314b2c3b65be1c90e051b2a670383bc..ab644ff5a61c407b246b97af5328bf5cd8c1893b 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable) { std::map variables = GetVariables(ctx); - int64 num_resource_args = variables.size(); xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - num_resource_args, client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), true); launch_context.PopulateInputs(ctx, result, variables); @@ -157,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile( options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, - /*compile_options=*/nullptr); + result, executable, &compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 23c6f3903f841a6c39104983c6f7f409757a7319..7cc3d0e007ba2974fbfbe6fbabc4aa08f9fa910f 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { // An OpKernel that compiles an op to an XLA computation and runs it. Unlike -// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a // vanilla TensorFlow op as long as the bridge supports it. -// -// Importantly _XlaLaunch assumes all input and output tensors are on the host, -// whereas XlacompileOnDemandOp works with tensors in device memory. class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index bc07dbd7bdf005fde781f7a1e6775080e363abfb..ea9e0366043a4a64bfe43703c55d4470693bbac8 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - registration, - /*transfer_as_literal=*/false, &device)); + TF_RETURN_IF_ERROR( + XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 70263b1ff936757101a3c47d192b2ba58271dc79..f13b46c532e6008477849f2e06887901c90038ab 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" @@ -49,6 +48,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, std::unique_ptr* device) { + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal)); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform, - bool transfer_as_literal) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal) { + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name; } @@ -230,10 +237,10 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { GetAllocator({}); // XlaDevice owns both gpu_device_info_ and // gpu_device_info_->default_context. - gpu_device_info_ = absl::make_unique(); + gpu_device_info_ = MakeUnique(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, client(), transfer_as_literal_); + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph, TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); // Call GetAllocator for the side-effect of ensuring the allocator is created. GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3ae87308cc7cffa916e178893df70a3f314b11b0..d5d345d43b16c43c7a202791b2604b39d29e8cdb 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -27,6 +26,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice { class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type); + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); // The index of the device on this host. int device_ordinal() const; @@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice { se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. - static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - std::unique_ptr* device); + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // If set, holds default device context (that we must Unref) // and its stream. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index bf8c1886a022310eeaacdf69463f575a393dd8d0..ff30b62bad782f281bcd25275521ed8b0c4c0bfd 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal) +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) {} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { @@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); + + TensorShape shape; + if (shape_representation_fn_) { + shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); + } else { + shape = device_tensor->shape(); + } if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), device_tensor->shape(), client_, + device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. Status block_status = stream_->BlockHostUntilDone(); @@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal) - : manager_(stream, client, transfer_as_literal) {} +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index d7f5f1d208989256f8043d2e6d93cf9bd89333b2..9af9655868448ce5116db3611c5f88339135947e 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -69,7 +71,8 @@ class XlaTransferManager { // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. - bool transfer_as_literal_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -77,8 +80,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 498d25cf566a91f68e5eb1ac312e17900471aeca..9c00a0682ccdc08e7bb09e32d32f01e87e7aaf8d 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/sendrecv_ops.h" @@ -32,7 +33,7 @@ namespace tensorflow { // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be -// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an // operator on an XLA device but the compiler did not compile it. class XlaDeviceDummyOp : public OpKernel { public: @@ -41,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel { }; #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -63,6 +64,9 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ + IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index a8afbf9dcd736bb292b7c5f52c7cce2b47fb85b6..26842fbe5cc110fa9ce7a2767d245484fd67556d 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 9e098c46f422b436c722bb909dc58930ab7c0ef6..4146996f6346446e715ffb225882cfb20359dae1 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -48,10 +48,11 @@ Status XlaInterpreterDeviceFactory::CreateDevices( registration.compile_resource_ops = true; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, - DEVICE_INTERPRETER_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, + options, name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 33e53612b91315349cc7ac276021150d701ccdf3..d0c7a9365125708b2af43f87c7617d8d84050a61 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables) { std::map snapshot; - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { + for (int i : variables) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - OptionalTensor& tensor = snapshot[first_variable + i]; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -61,19 +60,22 @@ XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) XlaAllocator::~XlaAllocator() {} -xla::StatusOr XlaAllocator::Allocate( +xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { - void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + AllocationAttributes attrs; + attrs.no_retry_on_failure = !retry_on_failure; + void* data = + wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { return errors::ResourceExhausted("Out of memory while trying to allocate ", size, " bytes."); - } else { - return se::DeviceMemoryBase(data, size); } + return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); } -Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) { - wrapped_->DeallocateRaw(mem->opaque()); +Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { + wrapped_->DeallocateRaw(mem.opaque()); return Status::OK(); } @@ -112,10 +114,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer( using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - int64 num_resource_args, xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) - : num_resource_args_(num_resource_args), - client_(client), + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} @@ -194,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( - const_tensor.dtype(), const_tensor.shape(), - client_, stream->parent()->device_ordinal())); - } Device* device = dynamic_cast(ctx->device()); OP_REQUIRES(ctx, device != nullptr, @@ -240,7 +236,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } ++output_num; @@ -290,7 +286,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable->tensor() = output_tensor; } ++output_num; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 38291b0bd429b2b4f7939b2ec84213380f23d8bc..4390701ccbd0bc3971413ddcd917c11019990087 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -31,15 +33,17 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back +// Takes a snapshot of the values of resource variable arguments, whose +// indices are specified in `variables` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // -// Returns a map of TensorFlow argument index to resource variable. -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); +// Returns a map of TensorFlow argument index to resource variable. If a +// resource variable is not initialized, the corresponding OptionalTensor +// will have its `present` field set to false. +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -48,9 +52,9 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { public: XlaAllocator(const se::Platform* platform, Allocator* wrapped); ~XlaAllocator() override; - xla::StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place. Tensorflow uses the ordering of the main @@ -72,7 +76,7 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. - XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors); @@ -92,7 +96,6 @@ class XlaComputationLaunchContext { const std::vector& arguments() const { return arg_ptrs_; } private: - int64 num_resource_args_; xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 27813efc0bc0aecdbea2dfce5ca27ba704ea45e2..a45932403ec1760d6b985d5357fd6d84fbf257a2 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -36,9 +36,9 @@ void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { for (int i = 0; i < iters; ++i) { // Extract a buffer from approximately the middle of the first level of the // tree. - tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) + (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) .release(); } } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index ce6456880bc1b3bc15ac0ef4bae35a83771098ef..a7211c9c7e281a8141d5671b345c628441b2359d 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -52,20 +52,22 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); - xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), - device_ordinal); - for (auto& index_to_buffer : buffer.buffers()) { + xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, + client->backend().memory_allocator(), + device_ordinal); + for (auto& index_to_buffer : shaped_buffer.buffers()) { xla::Shape subshape = xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(index_to_buffer.second, + TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false)); + // Move our buffer into shaped_buffer, which takes ownership of it. + index_to_buffer.second = buffer.Forget(); } - set_shaped_buffer(xla::ScopedShapedBuffer( - std::move(buffer), client->backend().memory_allocator())); + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 922a91897312096e4bb6ee2a1cc153e0039e2c7a..6b29c82ec11e39ad525663991e179443c2b6dca7 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -54,7 +54,7 @@ class XlaTensor { // Some Tensors can have complex on-device shapes, including tuple shapes. To // manage the memory for these tensors a ShapedBuffer may be required. - // Return true if this TensorInfo contains a ShapedBuffer. + // Return true if this XlaTensor contains a ShapedBuffer. bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } // Return the contained ShapedBuffer. // REQUIRES: has_shaped_buffer() @@ -62,7 +62,7 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } - // Mutates the TensorInfo to set the ShapedBuffer. + // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = xla::MakeUnique(std::move(shaped_buffer)); @@ -72,7 +72,7 @@ class XlaTensor { // in on-demand mode to avoid re-copying values from the device if we know the // host value already. - // Return true if this TensorInfo contains a host tensor. + // Return true if this XlaTensor contains a host tensor. bool has_host_tensor() const { return host_tensor_ != nullptr; } // Return the contained host tensor. // REQUIRES: has_host_tensor() diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a94b298f87832057c6ec86a1ea250a54ed1b4ee0..4c291d2383163e5def54657186c2190c023832fc 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -127,7 +127,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +141,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +156,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +170,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +184,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -196,9 +196,11 @@ tf_xla_py_test( name = "oom_test", size = "medium", srcs = ["oom_test.py"], + # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", ], tags = [ # Allocates very large amounts of memory and does not work under TSAN. @@ -209,7 +211,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +227,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +243,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +265,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +293,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -300,10 +302,14 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -322,8 +328,12 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", ], ) @@ -338,7 +348,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -352,7 +362,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -364,7 +374,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -380,7 +390,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -395,12 +405,27 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "listdiff_op_test", + size = "small", + srcs = ["listdiff_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform_test", + "@six_archive//:six", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -408,7 +433,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -423,7 +448,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -435,7 +460,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -449,7 +474,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -462,7 +487,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -475,7 +500,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -490,7 +515,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -507,7 +532,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -522,7 +547,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -538,7 +563,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -551,7 +576,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -563,7 +588,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -575,7 +600,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -590,7 +615,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -603,7 +628,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -618,7 +643,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -634,7 +659,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -647,7 +672,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -661,7 +686,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -680,7 +705,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -693,7 +718,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -707,7 +732,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -726,7 +751,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -745,7 +770,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -760,7 +785,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -772,7 +797,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -785,21 +810,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) -cuda_py_test( +tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -816,11 +854,22 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", - "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", + ], +) + +cuda_py_test( + name = "dense_layer_test", + size = "small", + srcs = ["dense_layer_test.py"], + additional_deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", "//tensorflow/python:variables", ], ) @@ -864,7 +913,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -879,7 +928,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -917,7 +966,7 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -929,7 +978,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index ec547e16cd9c91a1e25bc963b9a3cafddf7326cd..9d3a889b1f54c813e881bb03b5275f809af1b3c8 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -29,51 +29,70 @@ from tensorflow.python.platform import test class ArgMinMaxTest(xla_test.XLATestCase): - def _assertOpOutputMatchesExpected(self, op, inp, expected): - """Verifies that 'op' produces 'expected' when fed input 'inp' . + def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, + expected): + """Verifies that 'op' produces 'expected' when fed input 'op_input' . Args: - op: operator to test - inp: numpy input array to use as input to 'op'. + op: argmin or argmax operator to test. + axis: integer axis to reduce across. + output_type: numpy datatype of the output to produce. + op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name="a") - output = op(pinp) - result = session.run(output, {pinp: inp}) + dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") + output = op(pinp, axis=axis, output_type=output_type) + result = session.run(output, {pinp: op_input}) self.assertAllEqual(result, expected) def testArgMinMax(self): # Complex numbers do not support argmin/argmax. minmax_types = set(self.numeric_types) - set(self.complex_types) for dtype in minmax_types: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([1, 10, 27, 3, 3, 4], dtype=dtype), - expected=np.int32(2)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([0, 1, 0], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([0, 0], dtype=np.int32)) + # output_type is a numpy data type that is used to specify the desired + # output type of the op as well as to convert the Python number to the + # array scalar of the type. + for output_type in self.int_types: + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=output_type(2)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=output_type)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([3, 10, 27, 3, 2, 4], dtype=dtype), - expected=np.int32(4)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([1, 0, 1], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([1, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=output_type(4)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=output_type)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..865f60ccab46ec6829e49409508303052944e13b --- /dev/null +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DenseLayer JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def GetRunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def XlaLaunchOpCount(labels): + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) + + +class DenseLayerTest(test.TestCase): + + def testDenseLayerAutoJit(self): + """Tests dense layer compilation in auto-jit mode. + + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + """ + + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = ( + config_pb2.OptimizerOptions.ON_1) + + with self.test_session(config=config) as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + def testDenseLayerJitScopeDefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer with static shape input tensor should be compiled into a single + XlaLaunch op by XLA. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + # No need to check whether ListDiff is compiled or not because ListDiff op + # is not used when input tensor shape is fully defined. + + def testDenseLayerJitScopeUndefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer uses shape op to get shape of input tensor if its shape is not + fully defined. XLA does not cluster shape op with other operators. But in + experimental_jit_scope, XLA is forced to compile shape op into its own + cluster, causing dense layer to be split into TWO XlaLaunch ops. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index bdd0185dfe4abe9d9acecc5381ff82c54b8c0705..52d8d6d295c428f2c3466ef2963223cc978b4277 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest @@ -43,7 +49,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen0(self): with self.test_scope(): - empty = constant_op.constant([], dtype=dtypes.int32) + empty = constant_op.constant([], dtype=dtypes.float32) result = array_ops.unstack(empty, 0) self.assertTrue(isinstance(result, list)) self.assertEqual(0, len(result)) @@ -51,7 +57,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen1(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 1, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) @@ -60,7 +66,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen3(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 3, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(3, len(result)) @@ -131,7 +137,173 @@ class EagerTest(XLATestCase): self.assertEqual(2., grads[0][0].numpy()) -if __name__ == "__main__": +class EagerFunctionTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + matmul = function.defun(math_ops.matmul, compiled=True) + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t, transpose_a=True) + self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) + + def testConv(self): + if 'GPU' in self.device: + # TODO(b/32333178) + self.skipTest('Current implementation of RandomStandardNormal kernel ' + 'is very slow on GPU, and has been blacklisted.') + with self.test_scope(): + data_format = 'channels_last' + conv = convolutional.Conv2D( + filters=1, kernel_size=2, padding='VALID', + data_format=data_format, activation=nn_ops.relu, + kernel_initializer=init_ops.ones_initializer(), + bias_initializer=init_ops.zeros_initializer()) + pool = pooling.MaxPooling2D(2, 2, data_format=data_format) + + def model(x): + x = conv(x) + return pool(x) + model = function.defun(model, compiled=True) + + x = array_ops.ones([1, 4, 4, 1]) + y = model(x) + self.assertAllEqual(y.numpy(), [[[[4.]]]]) + + def testReadVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun(compiled=True) + def f(): + return v.read_value() + + var = f() + self.assertEqual(1.0, var.numpy()) + + def testUpdateVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + def f(v): + v.assign_add(1.0) + return v + + f = function.defun(f, compiled=True) + + var = f(v) + self.assertEqual(2.0, var.numpy()) + + def testAllArgumentKinds(self): + """Test a complex function that takes different argument kinds. + + tf2xla machinery that translates, compiles, and runs defuns + classifies arguments into: compile-time constants, regular tensors, + and resources. This test creates a function with a mix of all these + kinds. Moreover, the order of function arguments is intentionally mixed up. + + This also tests the case when the same argument is a compile-time constant + as well as used in an operation that normally expects its inputs to be + in device memory - addition in this case. + """ + with self.test_scope(): + def foo(c1, r1, v1, c2, v2, r2): + # c1 and c2 are compile-time constants + # r1 and r2 are regular tensors + # v1 and v2 are resource variables + a = c1 + r1 + b = math_ops.cast(c2, dtypes.float32) + v2 + c = array_ops.slice(v1, c1, c2) + d = r2 * v2 + return a, b, c, d + + foo = function.defun(foo, compiled=True) + + c1 = [0, 0] + c2 = array_ops.ones([2], dtype=dtypes.int32) + + r1 = array_ops.ones([2]) + r2 = [[2., 2.], [3., 3.]] + + v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) + v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) + + a, b, c, d = foo(c1, r1, v1, c2, v2, r2) + + self.assertAllEqual([1, 1], a.numpy()) + self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) + self.assertAllEqual([[1.]], c.numpy()) + self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + +class ExcessivePaddingTest(XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. + reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + + +if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) googletest.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index fbc3c994d163a504351fcccd1ba71a0997e6516f..8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,12 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 1ad83d80409734efd1f5a0a9fc39f5b7d064d54b..4b0043b6b4c7fbf57ec1507b84adf18daaea9363 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -29,13 +29,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope @@ -80,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -92,8 +90,8 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: @@ -443,31 +441,14 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) - - def testDenseLayer(self): - """Tests that the dense layer node is properly compiled.""" - - with self.test_session(config=NoRewriteSessionConfig()) as sess: - x = array_ops.placeholder(shape=[2, 3], dtype=np.float32) - with jit_scope(): - y = layers.dense(x, 3) - - sess.run(variables.initialize_all_variables()) - run_metadata = config_pb2.RunMetadata() - sess.run(y, {x: np.array([[1, 2, 3], [4, 5, 6]])}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assertTrue(InLabels(labels, "XlaLaunch")) class ElementWiseFusionTest(test.TestCase): @@ -501,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("_XlaLaunch(" in x for x in labels) + count = sum("XlaLaunch(" in x for x in labels) return output, count diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..45a04f0cf56e88946b946bedacb25ce6da3121b4 --- /dev/null +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA listdiff operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ListDiffTest(xla_test.XLATestCase): + + def _testListDiff(self, x, y, out, idx): + for dtype in [dtypes.int32, dtypes.int64]: + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + with self.test_scope(): + out_tensor, idx_tensor = array_ops.listdiff( + x_tensor, y_tensor, out_idx=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(out, tf_out) + self.assertAllEqual(idx, tf_idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) + + def testBasic1(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3]) + + def testBasic2(self): + self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3]) + + def testBasic3(self): + self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2]) + + def testDuplicates(self): + self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1], + y=[4, 2], + out=[1, 3, 3, 3, 1], + idx=[0, 3, 5, 6, 7]) + + def testRandom(self): + num_random_tests = 10 + int_low = -7 + int_high = 8 + max_size = 50 + for _ in xrange(num_random_tests): + x_size = np.random.randint(max_size + 1) + x = np.random.randint(int_low, int_high, size=x_size) + y_size = np.random.randint(max_size + 1) + y = np.random.randint(int_low, int_high, size=y_size) + out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y] + if out_idx: + out, idx = map(list, zip(*out_idx)) + else: + out = [] + idx = [] + self._testListDiff(list(x), list(y), out, idx) + + def testFullyOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[]) + + def testNonOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], + y=[5, 6], + out=[1, 2, 3, 4], + idx=[0, 1, 2, 3]) + + def testEmptyX(self): + self._testListDiff(x=[], y=[1, 2], out=[], idx=[]) + + def testEmptyY(self): + self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3]) + + def testEmptyXY(self): + self._testListDiff(x=[], y=[], out=[], idx=[]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e53efc3091d8935e745122af29abd7b8063b1d01..16f293891d56d78885dd515bb7b9899faf0690f7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -619,8 +619,8 @@ std::vector OpTest::ImageDims(TensorFormat format, int batch, dims.push_back(dim); } break; - case FORMAT_NCHW_VECT_C: - LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + default: + LOG(FATAL) << "Tensor format " << ToString(format) << " not supported."; } return dims; } diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 4336ebdbd184a081619f0a6951dd4514735c6eb6..b6f8390a45d43bf7666b90e14cc6ff2f3f61947e 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase): # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) + def testRandomNormalIsFinite(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[10000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(np.isfinite(y))) + def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ba79f393a8f9b24ac506d2130957c38ecd442509..689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -209,7 +209,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -251,12 +252,12 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2]], dtype=dtype), expected=np.array([[0.540297, -0.41614]], dtype=dtype)) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.rint, @@ -333,13 +334,19 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -419,7 +426,9 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), - expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), + rtol=1e-6, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.reciprocal, @@ -441,13 +450,13 @@ class UnaryOpsTest(XLATestCase): np.array([[5j, 3 - 2j]], dtype=dtype), expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) self._assertOpOutputMatchesExpected( @@ -789,7 +798,9 @@ class UnaryOpsTest(XLATestCase): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1e30ebd55d09fe00449fb67b92a8325f5809d89a --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaDeviceGpuTest(test.TestCase): + + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305d740b994dadc34c93b4e0ae32d785..b707bd0963d71d7c4b43b8d42752b4c50e9bbf7c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,30 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return - - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 4fca51f54d320e843343f80d7df1177f80f1d99f..cd57452302fcbde37d79ce760a80615a76d7ad8c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -325,6 +325,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 4f8bb8ad743afe69a6544c2ae0dc7309891b2df3..ea8d1b3d14939d4f4fba598318200f71c2eb0270 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -27,3 +27,25 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_gen_op_wrapper_cc( + name = "xla_jit_op_gen", + out_ops_file = "ops/xla_jit_op", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) + +cc_library( + name = "xla_jit_ops", + srcs = ["ops/xla_jit_op.cc"], + hdrs = ["ops/xla_jit_op.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit/ops:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 8d1f2684909e876fe5521ba6a63d745c7d3956e0..42585ad4d8a17d71146e48b69f9fa56f9ff24c3e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function refered by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -1365,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); @@ -1434,7 +1499,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee3054c2914bb614bf75f7a51be8f6292683e..d941041d15532446d1413f16fe64602bfb1a7daa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index e494f42e8ed254ac0c7c7a23a13728d3f015e9d3..14977a908ae2b0ff7e13b634c41b6d331b4b8a36 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. // Graph: diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 8115a26210a8e9e95e851f350e34dcdfa2519a64..212f6f3966149ca0b2d2e012b19300e1f488f996 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); @@ -229,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. + int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 85ab4c41bf6a754236066260819f103970e603ae..e6da157c111ad9167bf7b1e743d9afbb8fb2ad03 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -45,6 +45,7 @@ tf_kernel_library( "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", + "listdiff_op.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index ed7462c16615f7f63a174e29843c2a1675c17058..493781a1e68b8906f1a7e018e5710130e2eb08b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -34,9 +34,8 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); } }; @@ -68,13 +67,12 @@ class SeluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), b->Mul(scale_alpha, expm1))); } diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0388b4c830702ea00ec69fc42c6468326c88cf38 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 +// input. + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +constexpr std::array kListDiffTypes = {DT_INT32, DT_INT64}; + +// ListDiffOp is an XLA kernel that supports constant-only x and y input. +class ListDiffOp : public XlaOpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)), + errors::InvalidArgument("ListDiff expects x as a vector, not ", + context->InputShape(0).DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)), + errors::InvalidArgument("ListDiff expects y as a vector, not ", + context->InputShape(1).DebugString())); + + DataType val_type = context->expected_output_dtype(0); + DataType idx_type = context->expected_output_dtype(1); + + Status status; + switch (val_type) { + case DT_INT32: + status = ListDiffWithIndexType(context, idx_type); + break; + case DT_INT64: + status = ListDiffWithIndexType(context, idx_type); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. + status = errors::InvalidArgument("ListDiff expects x and y as either ", + "int32 or int64, not ", + DataTypeString(val_type)); + } + OP_REQUIRES_OK(context, status); + } + + private: + template + Status ListDiff(XlaOpKernelContext* context) { + std::vector x_input, y_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); + + std::unordered_set y_input_set; + y_input_set.reserve(y_input.size()); + for (auto y : y_input) { + y_input_set.insert(y); + } + + std::vector val_output; + std::vector idx_output; + auto x_size = x_input.size(); + for (Tidx i = 0; i < x_size; ++i) { + if (y_input_set.count(x_input[i]) > 0) { + continue; + } + val_output.push_back(x_input[i]); + idx_output.push_back(i); + } + + context->SetOutput(0, context->builder()->ConstantR1(val_output)); + context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + return Status::OK(); + } + + template + Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + switch (idx_type) { + case DT_INT32: + return ListDiff(context); + case DT_INT64: + return ListDiff(context); + default: + return errors::InvalidArgument( + "ListDiff expects idx_out as either int32 or int64, not ", + DataTypeString(idx_type)); + } + } +}; + +REGISTER_XLA_OP(Name("ListDiff") + .TypeConstraint("T", kListDiffTypes) + .CompileTimeConstInput("x") + .CompileTimeConstInput("y"), + ListDiffOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 70547290eaed169599764a5d66185dde85345863..a711278638444be01fb865561957702368b75114 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 0ed4c4707df71cf5f56ccfe0af506916f04bcdb5..5d1c05268493f4f6404c40a4092a71f1e5b3f3b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. + auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 6340c225185e68df638747def5b4fda3ef4c28ac..a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -255,7 +255,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + auto uniform = + RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a4f50f52ebe8b1ed7df862996d64e135ea1d0ac5..71a9fd051bfc8db09738a4bfe8ddde447895ecf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -100,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh, XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); -// TODO(b/34703906): use a more accurate implementation of expm1. -XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); @@ -115,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, b->Log(x)); -// TODO(b/34703906): use a more accurate implementation of log1p. -XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); @@ -160,24 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::XlaOp Softplus(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& features) { - xla::XlaOp threshold = b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. - xla::XlaOp too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - xla::XlaOp too_small = b->Lt(features, threshold); - xla::XlaOp features_exp = b->Exp(features); - xla::XlaOp output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large, it can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 04ad3694a0c0df9d43c706d428c3b8715e5ff8ca..ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -80,7 +80,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -141,7 +140,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 83e73827862ca26a1a51bed72ab87768854c1e71..3f1384bc864abd882ebba2b90acbe0b1e664687a 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -214,7 +214,7 @@ xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, /*lower=*/true, /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/8)); + /*block_size=*/block_size)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); } diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd658e0462368ac0b51938979b7a6815a7574..43e1c1e9fecec1c71db1509757251cb5d903ca49 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index f283b0236811f8d52e8fe2982a74c11c92cd20d8..220bec15538c36fa30abef9e729b64dbbb9f72b3 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d1946c332b0f903b710a19fbb79fc9923e89c43..f7098917b191058c53a1d6a5923e80e5e8319d72 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. + if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + bool is_entry_computation, xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments. + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -658,13 +678,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); xla::XlaBuilder builder(name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -681,35 +702,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -724,23 +732,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. - int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ca6cd822ef4effd48dbc3cc18d35d6642f303df1..bf496bd8bc81e67056eba380288bca88737cc00d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -38,7 +38,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -67,6 +67,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -171,7 +180,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. DataType type; TensorShape shape; @@ -206,10 +215,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -230,6 +241,8 @@ class XlaCompiler { std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { // Name of the compilation device to use. Needs to be live only during // XlaCompiler's constructor. @@ -250,8 +263,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -300,7 +312,8 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 6b8918b26179735a4518a422fed024fa534122f5..55772ca324872f6d5fac008de7819b7fae64966a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,12 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -225,7 +227,7 @@ TEST_F(XlaCompilerTest, Simple) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -320,7 +322,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE( + xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } { @@ -355,10 +358,80 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. + // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op. @@ -523,7 +596,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -746,13 +819,10 @@ TEST_F(XlaCompilerTest, Variables) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. -TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -763,7 +833,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -778,15 +856,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. + XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -811,7 +907,76 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 3dd2d183f3a538786856dd8d92c5886b1cc237d8..098072d33cd4eb7f7dec0ec4196b43eca0220d4a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::XlaOp& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } @@ -117,9 +119,9 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, + DataType type) const { + return (*shape_representation_fn_)(shape, type); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1136ffe5073a8e7fd3c27d6ec7050cb1f8307584..341bf6ff1f37fa7cd81f41c02a941214067b1bd1 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -42,11 +42,13 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -58,14 +60,26 @@ class XlaContext : public ResourceBase { bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -86,9 +100,9 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. + TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -131,15 +145,23 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b65f4d5d5936e062e5351a0723544191ffe2dfa..76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e309cb1e34db7f8430c2494c03aed41652b7a167..4692038b61f6871a8a16299fd4d11e963eb46a57 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1af9cb6d2ab15a33b56f1df0410f47d7e139a1ba..c6deb959a59f7b79500a0948b4035ea56cd9b4a1 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -99,8 +99,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", + ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -244,6 +245,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", @@ -302,13 +304,13 @@ cc_library( ":array2d", ":array3d", ":array4d", - ":shape_tree", ":shape_util", ":sparse_index_array", ":status_macros", ":types", ":util", ":xla_data_proto", + "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -323,12 +325,30 @@ tf_cc_test( ":shape_util", ":test", ":types", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +cc_library( + name = "error_spec", + hdrs = ["error_spec.h"], +) + +cc_library( + name = "literal_comparison", + srcs = ["literal_comparison.cc"], + hdrs = ["literal_comparison.h"], + deps = [ + ":error_spec", + ":literal_util", + ":util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "metric_table_report", srcs = ["metric_table_report.cc"], @@ -563,6 +583,7 @@ tf_cc_test( ":shape_util", ":test", ":xla_data_proto", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index aac3273d5fd144f3b737529b0833c9328b3d0e4d..aacb394ae5f92aa0d87ee3a23bcc3d4ec5cd99a3 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -76,7 +75,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -99,7 +98,6 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", @@ -126,7 +124,6 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -162,47 +159,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "computation_builder", - srcs = ["computation_builder.cc"], - hdrs = ["computation_builder.h"], - deps = [ - ":client", - ":computation", - ":global_data", - ":padding", - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 328e1b8fa84e7baaca41c6c9a65e9a1598ac32ae..c9d275a77b5cd40225f4b5c45e02c242d27d9aa1 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -161,22 +161,6 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -221,65 +205,11 @@ StatusOr> Client::ComputeConstant( return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; - - Status s = stub_->LoadComputationSnapshot(&request, &response); - if (!s.ok()) { - return s; - } - - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); -} - StatusOr Client::LoadSnapshot(const HloSnapshot& module) { TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, response.output()); -} - StatusOr> Client::Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -320,41 +250,6 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } -StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest request; @@ -372,7 +267,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-graph-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -401,7 +296,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -449,24 +344,6 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { @@ -488,23 +365,6 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); @@ -527,28 +387,6 @@ StatusOr Client::GetShape(const GlobalData& data) { return response.shape(); } -StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a63ff4c56d1dd78c7abfa2bf163b5fbd54d82b2b..d57e2536d0b44cda46d7c1c2513b82c9f8a31c1b 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,11 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,21 +51,6 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. - struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -177,17 +132,6 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -223,12 +167,6 @@ class Client { const GlobalData& data); // Retrieves the statistics of the given computation. - StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -239,13 +177,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> GetComputationShape( const XlaComputation& computation); @@ -253,9 +184,6 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); - - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -263,8 +191,6 @@ class Client { private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 96e38bca01087991943aff40ed1cb3e21f9e6cba..dc69d2097ebe14ca0e14a39849d4fcae99024fdc 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -21,24 +21,6 @@ limitations under the License. namespace xla { -StatusOr>> -CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return compiler_service_->CompileAheadOfTime(service_instances, options); -} - StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index c8725b8517484acdaf093bc3b34adb00f69155b1..f9a7c31270c7a11175f47a537639a97d0c9211af 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -38,26 +37,7 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); - // A description of an xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. @@ -69,8 +49,6 @@ class CompileOnlyClient : public Client { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. The |options| parameter describes // the target for which the compiler should emit code. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0c4cb969939883efebcf3a6d6be381..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. - ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index 9a1bcde76387297cb7f374b25baad1d5ec284859..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -// -// TODO(b/74197823): Deprecated. Use XlaComputation instead. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc deleted file mode 100644 index 83c7cb174402133706fbde6a734a29afd8edfe80..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ /dev/null @@ -1,1574 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation_builder.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { - -ComputationBuilder::ComputationBuilder(Client* client, - const string& computation_name) - : name_(computation_name), client_(client) {} - -ComputationBuilder::~ComputationBuilder() {} - -void ComputationBuilder::NoteError(const Status& error) { - if (die_immediately_on_error_) { - LOG(FATAL) << "error building computation: " << error; - } - - if (first_error_.ok()) { - first_error_ = error; - first_error_backtrace_.CreateCurrent(/*skip_count=*/1); - } -} - -std::unique_ptr ComputationBuilder::CreateSubBuilder( - const string& computation_name) { - auto sub_builder = MakeUnique(client_, computation_name); - sub_builder->parent_builder_ = this; - sub_builder->die_immediately_on_error_ = die_immediately_on_error_; - return sub_builder; -} - -Status ComputationBuilder::PrepareComputation() { - TF_RETURN_IF_ERROR(first_error_); - - if (!computation_.IsNull()) { - return Status::OK(); - } - - ComputationRequest request; - request.set_name(name_); - ComputationResponse response; - - VLOG(2) << "making computation request"; - Status s = client_->stub()->Computation(&request, &response); - VLOG(2) << "done with computation request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - computation_ = Computation(client_->stub(), response.computation()); - return Status::OK(); -} - -Status ComputationBuilder::RunOp(OpRequest* op_request, - OpResponse* op_response) { - TF_RETURN_IF_ERROR(first_error_); - TF_RETURN_IF_ERROR(PrepareComputation()); - - // Fill in fields that are set on every OpRequest. - *op_request->mutable_computation() = computation_.handle(); - *op_request->mutable_metadata() = metadata_; - if (sharding_) { - *op_request->mutable_sharding() = *sharding_; - } - - const string& op_name = - OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); - VLOG(2) << "running op request: " << op_name; - Status status = client_->stub()->Op(op_request, op_response); - VLOG(2) << "done with op request: " << op_name; - return status; -} - -void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - } -} - -ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( - OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - if (op_response.output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return op_response.output(); -} - -bool ComputationBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return true; - } else { - NoteError(InvalidArgument( - "%s", tensorflow::strings::StrCat( - "Window has different number of window dimensions than of ", - x_name, "\nNumber of window dimensions: ", - window_dimensions.size(), "\nNumber of ", x_name, ": ", x, - "\n") - .c_str())); // - return false; - } - }; - if (!verify_size(window_strides.size(), "window strides") || - !verify_size(padding.size(), "padding entries") || - !verify_size(lhs_dilation.size(), "lhs dilation factors") || - !verify_size(rhs_dilation.size(), "rhs dilation factors")) { - return false; - } - - window->Clear(); - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window->add_dimensions(); - dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return true; -} - -ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { - OpRequest op_request; - ConstantRequest* request = op_request.mutable_constant_request(); - *request->mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request->literal().ShortDebugString(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { - OpRequest op_request; - ParameterRequest* request = op_request.mutable_parameter_request(); - *request->mutable_shape() = shape; - request->set_parameter(parameter_number); - request->set_name(name); - return RunOpAndParseResponse(&op_request); -} - -StatusOr> ComputationBuilder::GetShapeWithoutNoteError( - const ComputationDataHandle& operand) { - GetLocalShapeRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - GetLocalShapeResponse response; - - VLOG(2) << "making get-shape request"; - TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); - VLOG(2) << "done with request"; - - TF_RET_CHECK(response.has_shape()); - std::unique_ptr shape = WrapUnique(response.release_shape()); - TF_RET_CHECK(shape != nullptr); - return std::move(shape); -} - -StatusOr> ComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - auto status_or_shape = GetShapeWithoutNoteError(operand); - if (!status_or_shape.ok()) { - NoteError(status_or_shape.status()); - return first_error_; - } - return status_or_shape; -} - -StatusOr ComputationBuilder::GetProgramShape() { - TF_RETURN_IF_ERROR(first_error_); - - GetComputationShapeRequest request; - *request.mutable_computation() = computation_.handle(); - GetComputationShapeResponse response; - - VLOG(2) << "making get-program-shape-request"; - Status status = client_->stub()->GetComputationShape(&request, &response); - VLOG(2) << "done with get-program-shape-request"; - - if (!status.ok()) { - first_error_ = status; - return status; - } - - TF_RET_CHECK(response.has_program_shape()); - return std::move(*response.mutable_program_shape()); -} - -ComputationDataHandle ComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - OpRequest op_request; - SliceRequest* request = op_request.mutable_slice_request(); - *request->mutable_operand() = operand; - for (int64 index : start_indices) { - request->add_start_indices(index); - } - for (int64 index : limit_indices) { - request->add_limit_indices(index); - } - for (int64 index : strides) { - request->add_strides(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - NoteError(shape_status.status()); - return ComputationDataHandle{}; - } - const Shape& shape = *shape_status.ValueOrDie(); - std::vector starts(ShapeUtil::Rank(shape), 0); - std::vector limits(shape.dimensions().begin(), - shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); - starts[dimno] = start_index; - limits[dimno] = limit_index; - strides[dimno] = stride; - return Slice(operand, starts, limits, strides); -} - -ComputationDataHandle ComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - OpRequest op_request; - DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_start_indices() = start_indices; - for (int64 index : slice_sizes) { - request->add_slice_sizes(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - OpRequest op_request; - DynamicUpdateSliceRequest* request = - op_request.mutable_dynamic_update_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_update() = update; - *request->mutable_start_indices() = start_indices; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - OpRequest op_request; - ConcatenateRequest* request = op_request.mutable_concatenate_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - request->set_dimension(dimension); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { - OpRequest op_request; - BroadcastRequest* request = op_request.mutable_broadcast_request(); - *request->mutable_operand() = operand; - for (int64 size : broadcast_sizes) { - request->add_broadcast_sizes(size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - OpRequest op_request; - PadRequest* request = op_request.mutable_pad_request(); - *request->mutable_operand() = operand; - *request->mutable_padding_value() = padding_value; - *request->mutable_padding_config() = padding_config; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { - OpRequest op_request; - ReshapeRequest* request = op_request.mutable_reshape_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (int64 new_size : new_sizes) { - request->add_new_sizes(new_size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - std::vector dimensions(shape.ValueOrDie()->dimensions().size()); - std::iota(dimensions.begin(), dimensions.end(), 0); - return Reshape(operand, dimensions, new_sizes); -} - -ComputationDataHandle ComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - // Don't support out-of-order collapse here. - // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { - if (dimensions[i] - 1 != dimensions[i - 1]) { - NoteError(InvalidArgument( - "Collapsed dimensions are not in order and consecutive.")); - return ComputationDataHandle(); - } - } - - // Create a new sizes vector from the old shape, replacing the collapsed - // dimensions by the product of their sizes. - StatusOr> shape_or_status = GetShape(operand); - if (!shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); - - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); - - if (dimensions.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. - return operand; - } - - std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dimensions.front() || i > dimensions.back()) { - new_sizes.push_back(original_shape->dimensions(i)); - } else { - new_sizes.back() *= original_shape->dimensions(i); - } - } - - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - - return Reshape(operand, new_sizes); -} - -void ComputationBuilder::Trace(const string& tag, - const ComputationDataHandle& operand) { - OpRequest op_request; - TraceRequest* request = op_request.mutable_trace_request(); - request->set_tag(tag); - *request->mutable_operand() = operand; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); -} - -ComputationDataHandle ComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - OpRequest op_request; - VariadicOpRequest* request = op_request.mutable_variadic_op_request(); - request->set_varop(VAROP_TUPLE); - for (const ComputationDataHandle& operand : elements) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - OpRequest op_request; - GetTupleElementRequest* request = - op_request.mutable_get_tuple_element_request(); - *request->mutable_operand() = tuple_data; - request->set_index(index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - - DotDimensionNumbers dimension_numbers; - dimension_numbers.add_lhs_contracting_dimensions( - lhs_shape->dimensions_size() == 1 ? 0 : 1); - dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers) { - OpRequest op_request; - DotRequest* request = op_request.mutable_dot_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conv( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return ConvWithGeneralDimensions( - lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -bool ComputationBuilder::VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { - NoteError( - InvalidArgument("Convolution arguments must have same number of " - "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 2) { - NoteError(InvalidArgument( - "Convolution expects argument arrays with >= 3 dimensions. " - "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_spatial_dims = num_dims - 2; - - const auto check_spatial_dimensions = - [&](const char* const field_name, - const tensorflow::protobuf::RepeatedField& - numbers) { - if (numbers.size() != num_spatial_dims) { - NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", - num_spatial_dims, field_name, - numbers.size())); - return false; - } - for (int i = 0; i < numbers.size(); ++i) { - if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - NoteError( - InvalidArgument("Convolution %s[%d] is out of bounds: %lld", - field_name, i, numbers.Get(i))); - return false; - } - } - return true; - }; - return check_spatial_dimensions( - "input_spatial_dimensions", - dimension_numbers.input_spatial_dimensions()) && - check_spatial_dimensions( - "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()) && - check_spatial_dimensions( - "output_spatial_dimensions", - dimension_numbers.output_spatial_dimensions()); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - NoteError(InternalError("failed to verify convolution")); - return ComputationDataHandle(); - } - - std::vector base_area_dimensions( - dimension_numbers.input_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < base_area_dimensions.size(); - ++i) { - base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - return ConvGeneral(lhs, rhs, window_strides, - MakePadding(base_area_dimensions, window_dimensions, - window_strides, padding), - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - // Error is recorded in VerifyConvolution. - return ComputationDataHandle(); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - OpRequest op_request; - ConvolveRequest* request = op_request.mutable_convolve_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - - if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request->mutable_window())) { - // Error is recorded in MakeWindow. - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Fft( - const ComputationDataHandle& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { - OpRequest op_request; - FftRequest* request = op_request.mutable_fft_request(); - *request->mutable_operand() = operand; - request->set_fft_type(fft_type); - for (int64 dim_len : fft_length) { - request->add_fft_length(dim_len); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - OpRequest op_request; - InfeedRequest* request = op_request.mutable_infeed_request(); - *request->mutable_shape() = shape; - *request->mutable_config() = config; - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, - const string& outfeed_config) { - OpRequest op_request; - OutfeedRequest* request = op_request.mutable_outfeed_request(); - request->set_outfeed_config(outfeed_config); - *request->mutable_operand() = operand; - *request->mutable_shape() = shape_with_layout; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands) { - OpRequest op_request; - CallRequest* request = op_request.mutable_call_request(); - *request->mutable_to_apply() = computation.handle(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - OpRequest op_request; - CustomCallRequest* request = op_request.mutable_custom_call_request(); - request->set_call_target_name(call_target_name); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - OpRequest op_request; - HostComputeRequest* request = op_request.mutable_host_compute_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - request->set_channel_name(channel_name); - request->set_cost_estimate_ns(cost_estimate_ns); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Conj( - const ComputationDataHandle& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -ComputationDataHandle ComputationBuilder::Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); -} - -// TODO(b/65209188): Create a dedicated lowering for Xor -ComputationDataHandle ComputationBuilder::Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); -} - -ComputationDataHandle ComputationBuilder::Not( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NOT, operand); -} - -ComputationDataHandle ComputationBuilder::ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Abs( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ABS, operand); -} - -ComputationDataHandle ComputationBuilder::Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Exp( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_EXP, operand); -} - -ComputationDataHandle ComputationBuilder::Floor( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_FLOOR, operand); -} - -ComputationDataHandle ComputationBuilder::Ceil( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CEIL, operand); -} - -ComputationDataHandle ComputationBuilder::Round( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand); -} - -ComputationDataHandle ComputationBuilder::Log( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOG, operand); -} - -ComputationDataHandle ComputationBuilder::Sign( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIGN, operand); -} - -ComputationDataHandle ComputationBuilder::Cos( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_COS, operand); -} - -ComputationDataHandle ComputationBuilder::Sin( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIN, operand); -} - -ComputationDataHandle ComputationBuilder::Tanh( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_TANH, operand); -} - -ComputationDataHandle ComputationBuilder::Real( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_REAL, operand); -} - -ComputationDataHandle ComputationBuilder::Imag( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IMAG, operand); -} - -ComputationDataHandle ComputationBuilder::IsFinite( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IS_FINITE, operand); -} - -ComputationDataHandle ComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - OpRequest op_request; - TransposeRequest* request = op_request.mutable_transpose_request(); - *request->mutable_operand() = operand; - for (int64 dimension : permutation) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - OpRequest op_request; - ReverseRequest* request = op_request.mutable_reverse_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Sort( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SORT, operand); -} - -ComputationDataHandle ComputationBuilder::SqrtF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BitcastConvertType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_bitcast_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SquareF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::ReciprocalF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Neg( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NEGATE, operand); -} - -ComputationDataHandle ComputationBuilder::Clz( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CLZ, operand); -} - -ComputationDataHandle ComputationBuilder::Clamp( - const ComputationDataHandle& min, const ComputationDataHandle& operand, - const ComputationDataHandle& max) { - return TernaryOp(TRIOP_CLAMP, min, operand, max); -} - -ComputationDataHandle ComputationBuilder::UnaryOp( - UnaryOperation unop, const ComputationDataHandle& operand) { - OpRequest op_request; - UnaryOpRequest* request = op_request.mutable_unary_op_request(); - request->set_unop(unop); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - OpRequest op_request; - BinaryOpRequest* request = op_request.mutable_binary_op_request(); - request->set_binop(binop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - for (int64 dimension : broadcast_dimensions) { - request->add_broadcast_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape) { - OpRequest op_request; - RngRequest* request = op_request.mutable_rng_request(); - request->set_distribution(distribution); - for (const ComputationDataHandle& param : parameters) { - *request->add_parameter() = param; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::TernaryOp( - TernaryOperation triop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - OpRequest op_request; - TernaryOpRequest* request = op_request.mutable_ternary_op_request(); - request->set_triop(triop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_ehs() = ehs; - return RunOpAndParseResponse(&op_request); -} - -Status ComputationBuilder::SetReturnValue( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - SetReturnValueRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - - SetReturnValueResponse response; - - VLOG(2) << "making set-handle-to-execute request"; - Status s = client_->stub()->SetReturnValue(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - return Status::OK(); -} - -StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - TF_RETURN_IF_ERROR(first_error_); - - IsConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - request.set_num_parameters(num_parameters); - IsConstantResponse response; - - VLOG(2) << "making IsConstant request"; - Status s = client_->stub()->IsConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - return response.is_constant(); -} - -StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - TF_RETURN_IF_ERROR(first_error_); - - ComputeConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; - } - for (const auto& param : parameters) { - *request.add_parameters() = param.ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant request"; - Status s = client_->stub()->ComputeConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return InternalError( - "no computed literal in the provided response in ComputeConstant " - "request"); - } - return Literal::CreateFromProto(response.literal()); -} - -ComputationDataHandle ComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - OpRequest op_request; - MapRequest* request = op_request.mutable_map_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_to_apply() = computation.handle(); - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (const ComputationDataHandle& sop : static_operands) { - *request->add_static_operands() = sop; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); -} - -ComputationDataHandle ComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); -} - -ComputationDataHandle ComputationBuilder::While( - const Computation& condition, const Computation& body, - const ComputationDataHandle& init) { - OpRequest op_request; - WhileRequest* request = op_request.mutable_while_request(); - *request->mutable_condition() = condition.handle(); - *request->mutable_body() = body.handle(); - *request->mutable_init() = init; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - OpRequest op_request; - GatherRequest* gather_request = op_request.mutable_gather_request(); - *gather_request->mutable_input() = input; - *gather_request->mutable_gather_indices() = gather_indices; - *gather_request->mutable_dimension_numbers() = dimension_numbers; - for (int64 window_bound : window_bounds) { - gather_request->add_window_bounds(window_bound); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation) { - OpRequest op_request; - ConditionalRequest* request = op_request.mutable_conditional_request(); - *request->mutable_predicate() = predicate; - *request->mutable_true_operand() = true_operand; - *request->mutable_true_computation() = true_computation.handle(); - *request->mutable_false_operand() = false_operand; - *request->mutable_false_computation() = false_computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - OpRequest op_request; - ReduceRequest* request = op_request.mutable_reduce_request(); - *request->mutable_operand() = operand; - *request->mutable_init_value() = init_value; - for (int64 dimension : dimensions_to_reduce) { - request->add_dimensions(dimension); - } - *request->mutable_to_apply() = computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReduceAll( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - std::vector all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie())); - std::iota(all_dimnos.begin(), all_dimnos.end(), 0); - return Reduce(operand, init_value, computation, all_dimnos); -} - -ComputationDataHandle ComputationBuilder::ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - Status padding_valid = - ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides); - if (!padding_valid.ok()) { - first_error_ = padding_valid; - return ComputationDataHandle(); - } - - std::vector> padding_values = - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); -} - -ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - OpRequest op_request; - ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); - *request->mutable_operand() = operand; - *request->mutable_to_apply() = computation.handle(); - *request->mutable_init_value() = init_value; - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormTraining( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormTrainingRequest* request = - op_request.mutable_batch_norm_training_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormInferenceRequest* request = - op_request.mutable_batch_norm_inference_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - *request->mutable_mean() = mean; - *request->mutable_variance() = variance; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormGrad( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, float epsilon, - int64 feature_index) { - OpRequest op_request; - BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_mean() = batch_mean; - *request->mutable_variance() = batch_var; - *request->mutable_grad_output() = grad_output; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - OpRequest op_request; - CrossReplicaSumRequest* request = - op_request.mutable_cross_replica_sum_request(); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - return SelectAndScatterWithGeneralPadding( - operand, select, window_dimensions, window_strides, - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding), - source, init_value, scatter); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - OpRequest op_request; - SelectAndScatterRequest* request = - op_request.mutable_select_and_scatter_request(); - *request->mutable_operand() = operand; - *request->mutable_select() = select.handle(); - *request->mutable_source() = source; - *request->mutable_init_value() = init_value; - *request->mutable_scatter() = scatter.handle(); - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReducePrecision( - const ComputationDataHandle& operand, const int exponent_bits, - const int mantissa_bits) { - OpRequest op_request; - ReducePrecisionRequest* request = - op_request.mutable_reduce_precision_request(); - *request->mutable_operand() = operand; - request->set_exponent_bits(exponent_bits); - request->set_mantissa_bits(mantissa_bits); - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Send(const ComputationDataHandle& operand, - const ChannelHandle& handle) { - OpRequest op_request; - SendRequest* request = op_request.mutable_send_request(); - *request->mutable_operand() = operand; - *request->mutable_channel_handle() = handle; - *op_request.mutable_computation() = computation_.handle(); - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, - const ChannelHandle& handle) { - OpRequest op_request; - RecvRequest* request = op_request.mutable_recv_request(); - *request->mutable_shape() = shape; - *request->mutable_channel_handle() = handle; - return RunOpAndParseResponse(&op_request); -} - -Computation ComputationBuilder::BuildAndNoteError() { - DCHECK(parent_builder_ != nullptr); - auto build_status = Build(); - if (!build_status.ok()) { - parent_builder_->NoteError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); - return Computation(); - } - return build_status.ConsumeValueOrDie(); -} - -StatusOr ComputationBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - if (computation_.IsNull()) { - return FailedPrecondition("no computation was built"); - } - - return {std::move(computation_)}; -} - -/* static */ ConvolutionDimensionNumbers -ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(kConvBatchDimension); - dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_output_batch_dimension(kConvBatchDimension); - dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_kernel_output_feature_dimension( - kConvKernelOutputDimension); - dimension_numbers.set_kernel_input_feature_dimension( - kConvKernelInputDimension); - for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(i + 2); - dimension_numbers.add_kernel_spatial_dimensions(i + 2); - dimension_numbers.add_output_spatial_dimensions(i + 2); - } - return dimension_numbers; -} - -/* static */ StatusOr -ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({input_batch, input_feature, input_first_spatial, - input_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", - input_batch, input_feature, input_first_spatial, input_second_spatial); - } - if (std::set({kernel_output_feature, kernel_input_feature, - kernel_first_spatial, kernel_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", - kernel_output_feature, kernel_input_feature, kernel_first_spatial, - kernel_second_spatial); - } - if (std::set({output_batch, output_feature, output_first_spatial, - output_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", - output_batch, output_feature, output_first_spatial, - output_second_spatial); - } - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(input_batch); - dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.add_input_spatial_dimensions(input_first_spatial); - dimension_numbers.add_input_spatial_dimensions(input_second_spatial); - dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); - dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); - dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); - dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_output_spatial_dimensions(output_first_spatial); - dimension_numbers.add_output_spatial_dimensions(output_second_spatial); - return dimension_numbers; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h deleted file mode 100644 index ac1eb915cc52df94df71631a7e80de9095f7fafb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ /dev/null @@ -1,1067 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Wraps an XLA client with a convenient interface for building up -// computations. Any errors encountered in building up the computation are -// deferred from being handled until Build() is called. -// -// Thread-compatible. -// -// TODO(b/74197823): Deprecated. Use XlaBuilder instead. -class ComputationBuilder { - public: - // client: client in which to build the computation. - // computation_name: name to use for the built computation. - ComputationBuilder(Client* client, const string& computation_name); - - ~ComputationBuilder(); - - // Returns the client the builder was initialized with. - Client* client() const { return client_; } - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Retrieves the (inferred) shape of the operand in the computation. - StatusOr> GetShape( - const ComputationDataHandle& operand); - - // Retrieves the (inferred) result for the current computation's shape. - StatusOr GetProgramShape(); - - // Enqueues a constant with the value of the given literal onto the - // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. - template - ComputationDataHandle ConstantR0(NativeT value); - template - ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice values); - ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); - template - ComputationDataHandle ConstantR2( - std::initializer_list> values); - template - ComputationDataHandle ConstantFromArrayWithLayout( - const Array& values, const Layout& layout); - template - ComputationDataHandle ConstantFromArray(const Array& values); - template - ComputationDataHandle ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); - template - ComputationDataHandle ConstantR2FromArray2D(const Array2D& values); - template - ComputationDataHandle ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); - template - ComputationDataHandle ConstantR3FromArray3D(const Array3D& values); - template - ComputationDataHandle ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); - template - ComputationDataHandle ConstantR4FromArray4D(const Array4D& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template - ComputationDataHandle ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. - // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const ComputationDataHandle& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - ComputationDataHandle Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - ComputationDataHandle Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - ComputationDataHandle Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. - ComputationDataHandle Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - ComputationDataHandle Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - ComputationDataHandle Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a general dot instruction onto the computation. - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an - // error if either the input or the weight dimension numbers have conflicts. - static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. - ComputationDataHandle Conv(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - ComputationDataHandle ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - ComputationDataHandle ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. - ComputationDataHandle ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. - ComputationDataHandle Fft(const ComputationDataHandle& operand, - FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - ComputationDataHandle Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - ComputationDataHandle CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - ComputationDataHandle HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - ComputationDataHandle Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - ComputationDataHandle Conj(const ComputationDataHandle& operand); - - // Enqueues an add instruction onto the computation. - ComputationDataHandle Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - ComputationDataHandle Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - ComputationDataHandle Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - ComputationDataHandle Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - ComputationDataHandle Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. - ComputationDataHandle Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - ComputationDataHandle Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - ComputationDataHandle And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Not(const ComputationDataHandle& operand); - - ComputationDataHandle ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. - ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const Computation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - ComputationDataHandle ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); - - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. - ComputationDataHandle SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // Enqueues an abs instruction onto the computation. - ComputationDataHandle Abs(const ComputationDataHandle& operand); - - // Enqueues a atan2 instruction onto the computation. - ComputationDataHandle Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an exp instruction onto the computation. - ComputationDataHandle Exp(const ComputationDataHandle& operand); - - // Enqueues a floor instruction onto the computation. - ComputationDataHandle Floor(const ComputationDataHandle& operand); - - // Enqueues a ceil instruction onto the computation. - ComputationDataHandle Ceil(const ComputationDataHandle& operand); - - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. - ComputationDataHandle Round(const ComputationDataHandle& operand); - - // Enqueues an log instruction (natural logarithm) onto the computation. - ComputationDataHandle Log(const ComputationDataHandle& operand); - - // Enqueues a sign instruction onto the computation. - ComputationDataHandle Sign(const ComputationDataHandle& operand); - - // Enqueues a cosine instruction onto the computation. - ComputationDataHandle Cos(const ComputationDataHandle& operand); - - // Enqueues a sine instruction onto the computation. - ComputationDataHandle Sin(const ComputationDataHandle& operand); - - // Enqueues a tanh instruction onto the computation. - ComputationDataHandle Tanh(const ComputationDataHandle& operand); - - // Enqueues a real-part instruction onto the computation. - ComputationDataHandle Real(const ComputationDataHandle& operand); - - // Enqueues an imaginary-part instruction onto the computation. - ComputationDataHandle Imag(const ComputationDataHandle& operand); - - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - ComputationDataHandle SqrtF32(const ComputationDataHandle& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - ComputationDataHandle SquareF32(const ComputationDataHandle& operand); - - // Enqueues a lhs^rhs computation onto the computation. - ComputationDataHandle Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. - ComputationDataHandle IsFinite(const ComputationDataHandle& operand); - - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. - ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); - - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); - - // Enqueues a negate instruction onto the computation. - ComputationDataHandle Neg(const ComputationDataHandle& operand); - - // Enqueues a count-leading-zeros instruction onto the computation. - ComputationDataHandle Clz(const ComputationDataHandle& operand); - - // Enqueues a transpose instruction onto the computation. - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); - - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a sort (as increasing order) instruction onto the computation. - ComputationDataHandle Sort(const ComputationDataHandle& operand); - - // Enqueues a clamp instruction onto the computation. - ComputationDataHandle Clamp(const ComputationDataHandle& min, - const ComputationDataHandle& operand, - const ComputationDataHandle& max); - - // Enqueues a map instruction onto the computation. - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); - - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); - - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); - - // Enqueues a while node onto the computation. - ComputationDataHandle While(const Computation& condition, - const Computation& body, - const ComputationDataHandle& init); - - // Enqueues a conditional node onto the computation. - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, - const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - ComputationDataHandle Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with index greater than or equal to - // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. - // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a - // compile-time constant without evaluating the computation. - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters = 0); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& offset, - float epsilon, int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - ComputationDataHandle BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` - ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, - float epsilon, int64 feature_index); - - // Computes the value of a constant indicated by a - // ComputationDataHandle using a non-optimized interpreter on the host. - // - // The operand must be from the computation currently being built - - // i.e., returned from this builder with no intervening call to - // Build(). This happens to currently work regardless of that, but - // that may stop working at any time. - // - // The operand must represent a constant value, which in this case - // means that it must not statically depend on any parameter of the - // computation that is being built other then the ones specified on the - // parameter list. The parameters in the list will be indexed by their - // parameter id property so the number of parameters specified should be at - // least as many as the largest used parameter index. - // - // `IsConstant` can be used to test whether a computation is a compile-time - // constant without evaluation it. `ComputeConstant` only succeeds for - // computations where `IsConstant` returns true. - // - // This functionality can be useful when translating a computation - // into XLA where something that looked dynamic is required by - // XLA to be specified as a constant. E.g. the source - // computation (outside of XLA) may include a dynamic - // computation of the shape of something and ComputeConstant lets - // you determine what the value of that computation is in the case - // where the value can be determined at compile time. - // - // If output_layout is non-null, then the output of the computation - // will be stored using that layout. - StatusOr> ComputeConstant( - const ComputationDataHandle& operand, - const Layout* output_layout = nullptr, - tensorflow::gtl::ArraySlice parameters = {}); - - // Returns a new ComputationBuilder whose resultant Computation is used only - // by this ComputationBuilder. The sub-ComputationBuilder has the same - // die_immediately_on_error behavior as the parent. - std::unique_ptr CreateSubBuilder( - const string& computation_name); - - // Modifies the computation being built so that executions of it - // will return the value associated with operand, rather than the - // last expression enqueued on the ComputationBuilder. Any subsequent - // operations added to the ComputationBuilder will not have any effect unless - // SetReturnValue is called again. - Status SetReturnValue(const ComputationDataHandle& operand); - - // Builds the computation with the requested operations, or returns a non-ok - // status. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent ComputationBuilder and returns an empty computation if building - // failed. This function is intended to be used where the returned - // Computation is only used by the parent ComputationBuilder and hence further - // operation on the returned Computation will simply be error'ed out if an - // error occurred while building this computation. If the built computation is - // to be used by a ComputationBuilder other than the parent ComputationBuilder - // then Build() should be used instead. - Computation BuildAndNoteError(); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // ComputationDataHandle and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - private: - // Limited checking of convolution parameters. Returns false on - // error. - bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers); - - // The parent ComputationBuilder of a sub-ComputationBuilder. The - // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. - ComputationBuilder* parent_builder_{nullptr}; - - // Helper function for creating a Window proto from user-supplied - // data. Returns true if the user-supplied data was valid. - bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - Window* window); - - // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation unop, - const ComputationDataHandle& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - ComputationDataHandle BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - ComputationDataHandle TernaryOp(TernaryOperation triop, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - const ComputationDataHandle& ehs); - - // Internal helper method that does the building for a random number generator - // of a given distribution with an explicitly specified shape. - ComputationDataHandle RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - // Populates computation_ with a valid object or returns a failing status. - // This is used before any given operation is enqueued. - Status PrepareComputation(); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - // Helper function that runs the given op_request, filling in op_response. - // Before the op is run, PrepareComputation is called, and common fields in - // the op_request are filled in. - Status RunOp(OpRequest* op_request, OpResponse* op_response); - - // Helper function that calls RunOp and calls NoteError on failures. - void RunOpAndNoteError(OpRequest* op_request); - - // Helper function that calls RunOp and either returns the output computation - // data handle (on success) or a vacuous computation data handle (on failure). - ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); - - // Helper function that implements GetShape without noting errors. This makes - // it easier to ensure the real GetShape will note errors on every error path. - StatusOr> GetShapeWithoutNoteError( - const ComputationDataHandle& operand); - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The computation that operations are enqueued onto. - Computation computation_; - - // The client that the computation is created in. Not owned. - Client* client_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); -}; - -template -ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1( - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, - NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline ComputationDataHandle ComputationBuilder::ConstantR1( - const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( - const Array& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArray( - const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( - const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( - const Array3D& values) { - return ConstantFromArray(values); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( - const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class ScopedShardingAssignment { - public: - ScopedShardingAssignment(xla::ComputationBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::ComputationBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59eaa68ebeb47edbd2afbeabad0cd2623ebc6..2986d4060013703873b2cffb6aacbb012606d16f 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 9cd87f74735ff50df8a3382723c7d045ff6c9e52..3380af9f303b1dc2cec09aa37410ec40cdeaa526 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -92,21 +92,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, return MakeFakeDataViaDeviceOrDie(shape, client); } -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); - - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; -} - std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { CHECK(computation.proto().has_program_shape()) diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 9e06141b1f13d24cd033b72e31ee3a0442fe6a37..dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -32,12 +32,6 @@ namespace xla { std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client); -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1acc6f86860e526b5ff737c45041a863f21da145..a7c55c6b2b7fe2b5541ce71bf3eaa24114522fc5 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -48,7 +48,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& host_computation_layout = @@ -207,7 +207,7 @@ StatusOr LocalExecutable::ExecuteAndDump( return std::move(result); } -tensorflow::Status LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module) { session_module->clear_arguments(); @@ -219,8 +219,8 @@ tensorflow::Status LocalExecutable::RecordArguments( return Status::OK(); } -tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + SessionModule* session_module) { session_module->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); @@ -261,25 +261,6 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index d8fd7a5623d1fecdcff6851aa3e3538822fb50da..d63d4ec7f3744d507cc854213e430e25e861e559 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -59,7 +58,7 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. - tensorflow::Status ValidateExecutionOptions( + Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -71,13 +70,13 @@ class LocalExecutable { // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, + SessionModule* session_module); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -108,17 +107,8 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 1899983e442116d3ebf8a3e79b0515653cd624cb..ae506317c2e4862d77cb4f0628e919871ad1aeb2 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -57,16 +57,6 @@ bool CanBeRoot(HloOpcode opcode) { } } -StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) { - std::vector operand_shapes; - for (const XlaOp& operand : operands) { - TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); - operand_shapes.push_back(shape); - } - return operand_shapes; -} - } // namespace StatusOr XlaBuilder::GetShape(const XlaOp& op) const { @@ -76,12 +66,14 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { return instr->shape(); } -StatusOr XlaOp::GetShape() const { - if (builder_ == nullptr) { - return InvalidArgument( - "cannot GetShape for an invalid XlaOp with handle %lld", handle()); +StatusOr> XlaBuilder::GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + operand_shapes.push_back(shape); } - return builder_->GetShape(*this); + return operand_shapes; } XlaBuilder::XlaBuilder(const string& computation_name) @@ -286,7 +278,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); @@ -325,7 +317,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferUnaryOpShape(unop, operand_shape)); return AddInstruction(std::move(instr), unop, {operand}); @@ -337,8 +329,8 @@ XlaOp XlaBuilder::BinaryOp( tensorflow::gtl::ArraySlice broadcast_dimensions) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); @@ -374,12 +366,12 @@ XlaOp XlaBuilder::BinaryOp( updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; } - TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(instr.shape(), updated_lhs)); } - TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(instr.shape(), updated_rhs)); @@ -393,9 +385,9 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferTernaryOpShape( triop, lhs_shape, rhs_shape, ehs_shape)); @@ -437,7 +429,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); @@ -485,7 +477,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); @@ -633,7 +625,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); @@ -647,7 +639,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -1002,7 +994,7 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); @@ -1173,6 +1165,10 @@ XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnaryOp(HloOpcode::kExp, operand); } +XlaOp XlaBuilder::Expm1(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExpm1, operand); +} + XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnaryOp(HloOpcode::kFloor, operand); } @@ -1189,6 +1185,10 @@ XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnaryOp(HloOpcode::kLog, operand); } +XlaOp XlaBuilder::Log1p(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog1p, operand); +} + XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnaryOp(HloOpcode::kSign, operand); } @@ -1225,7 +1225,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferTransposeShape(operand_shape, permutation)); @@ -1948,11 +1948,18 @@ StatusOr XlaBuilder::LookUpInstruction( const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); + if (op.builder_ == nullptr) { + return InvalidArgument( + "invalid XlaOp with handle %lld; the builder of this op is freed", + op.handle()); + } if (op.builder_ != this) { - return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + return InvalidArgument( + "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "it in builder '%s'", + op.handle(), op.builder_->name().c_str(), this->name().c_str()); } - TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 4955f1515d66af00ddf72e4c7621292a590e662c..2b3013a91c488782098bd81994e899eae5a1f506 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ @@ -48,15 +44,11 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: XlaOp() : handle_(0), builder_(nullptr) {} ~XlaOp() {} - StatusOr GetShape() const; - const XlaBuilder* builder() const { return builder_; } bool operator==(const XlaOp& rhs) const { @@ -87,8 +79,6 @@ class XlaOp { // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -139,7 +129,7 @@ class XlaBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - XlaOp ConstantLiteral(const Literal& literal); + XlaOp ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA @@ -571,6 +561,9 @@ class XlaBuilder { // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); @@ -584,6 +577,9 @@ class XlaBuilder { // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); @@ -847,6 +843,10 @@ class XlaBuilder { // computation and fills the root_id in the pointer. StatusOr GetProgramShape(int64* root_id) const; + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. The visitor walks the @@ -981,8 +981,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { // RAII-style object: sets the current sharding assignment in builder on // construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index ce984564d016ce65fa6c932f3cda290cc0d75a4a..2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -76,7 +76,7 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { auto y = b.Parameter(1, y_shape, "y"); auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -188,8 +188,10 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { builder.Add(p0, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Do not add XlaOp from builder b1 to builder main")); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); } TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index b70b57e9ffec40188f246f5e884146012c02f4a2..0ffba208b1f8683fe1d26107cbfd096b856267f1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: XlaComputation() : unique_id_(-1) {} diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..a1463aa15941b9c265db94e2eb3cc176fab6695b --- /dev/null +++ b/tensorflow/compiler/xla/error_spec.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ +#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index c6f8f6766e9d0156d0c68306af214443f584a9fe..a76fdcda250168cbed2acd01bdd9ddc3b4c93b92 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -140,8 +140,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { @@ -150,12 +149,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); + return Status::OK(); } else if (ShapeUtil::IsOpaque(shape)) { if (shape.has_layout()) { return InvalidArgument("opaque should not have a layout field"); } - return tensorflow::Status::OK(); + return Status::OK(); } else { // Array shape. if (!shape.has_layout()) { @@ -166,14 +165,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { @@ -225,7 +224,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -384,7 +383,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. -tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -411,14 +410,13 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6cec7501015e2dff6b5e56e20b793a5458618501..d3d6a2cc94012f7113fd1cb1b17e9c9d5323d9bf 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -61,12 +61,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. - static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +179,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. Recursively compares layouts of tuples. diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index bc8405703b02dc1b0c4c87005ea3c15372552157..f42fb92359f40ec763866af094972046f6407ae1 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -47,6 +47,12 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); + + // Run all GPU work on one stream by default. Using multiple streams + // increases memory usage and we lack strong motivating benchmarks for tuning + // the heuristics needed to decide when to run on multiple streams. See + // b/77879207. + flags->set_xla_gpu_disable_multi_streaming(true); } // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc new file mode 100644 index 0000000000000000000000000000000000000000..3696fdbe12e311af3b286ef0dfe91377983b72dd --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -0,0 +1,739 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_comparison.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" + +using tensorflow::strings::Appendf; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +namespace xla { +namespace literal_comparison { +namespace { + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); + if (ulhs != urhs) { + return InvalidArgument( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + } + return Status::OK(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +Status CompareEqual(NativeT lhs, NativeT rhs) { + if (lhs == rhs) { + return Status::OK(); + } + return InvalidArgument("Expected equality of these values:\n %s\n %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str()); +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(float lhs, float rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(double lhs, double rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(complex64 lhs, complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res.ok()) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} + +// A recursive function which iterates through every index of expected and +// actual literal and compares their values elementwise. Returns true if all +// elements are equal. +template +Status Equal(LiteralSlice expected, LiteralSlice actual, + tensorflow::gtl::MutableArraySlice multi_index, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); + return CompareEqual(expected_value, actual_value); + } + + Status result; + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index[dimension] = i; + result.Update(Equal(expected, actual, multi_index, dimension + 1)); + } + return result; +} + +// Gets the total element count. For tuples, this is not the count of tuple +// elements, but the sum of elements of each tuple element. +int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } +} + +// Returns whether the actual and expected values are mismatched with respect to +// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +template +bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } +} + +template <> +bool NanMismatch(complex64 expected, complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +} + +template <> +bool NanMismatch(half expected, half actual, bool relaxed_nans) { + return NanMismatch(static_cast(expected), + static_cast(actual), relaxed_nans); +} + +// Converts the given floating-point value to a string. +template +string FpValueToString(NativeT value) { + return Printf("%8.4g", static_cast(value)); +} + +template <> +string FpValueToString(complex64 value) { + return Printf("%8.4g + %8.4fi", value.real(), value.imag()); +} + +// Returns the absolute value of the given floating point value. This function +// is used instead of std::abs directly in order to allow type-dependent +// implementations for NearComparator. +template +float FpAbsoluteValue(NativeT value) { + return std::abs(value); +} + +template <> +float FpAbsoluteValue(bfloat16 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +float FpAbsoluteValue(half value) { + return FpAbsoluteValue(static_cast(value)); +} + +// Helper class for comparing floating-point literals within an error bound. +template +class NearComparator { + public: + // Compares the two array literals elementwise and returns a comparison + // result. The comparison is ok() if all actual and expected elements are + // within the given error bound. In case of error, the status contains a + // detailed message about the discrepancy. + static Status Compare(const LiteralSlice& expected, + const LiteralSlice& actual, ErrorSpec error, + bool detailed_message, + const MiscompareCallback& miscompare_callback) { + NearComparator comparator(expected, actual, error, + detailed_message, miscompare_callback); + return comparator.Run(); + } + + private: + // Data structure encapsulating metadata about a single element mismatch. + struct Mismatch { + NativeT actual; + NativeT expected; + float rel_error; + float abs_error; + + // The linear index of the failure within the shape. This linear index is + // from the 'actual' literal. + int64 linear_index; + + bool operator<(const Mismatch& other) const { + return rel_error < other.rel_error; + } + + string ToString(const Shape& shape) const { + return Printf( + "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", + FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + Literal::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + rel_error, abs_error); + } + }; + + NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, + ErrorSpec error, bool detailed_message, + const MiscompareCallback& miscompare_callback) + : expected_(expected), + actual_(actual), + error_(error), + detailed_message_(detailed_message), + miscompare_callback_(miscompare_callback), + abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), + abs_error_buckets_(kErrorBucketBounds.size(), 0), + rel_error_buckets_(kErrorBucketBounds.size(), 0) {} + + // Runs the comparison between expected and actual literals. + Status Run() { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, ToStringTruncated(expected_)); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, ToStringTruncated(actual_)); + + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); + if (!ShapeUtil::IsArray(expected_.shape())) { + return InvalidArgument("Expected array shape; got %s.", + ShapeUtil::HumanString(expected_.shape()).c_str()); + } + + mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); + mismatches_.PopulateWithValue(false); + + CompareLiterals(); + + if (num_mismatches_ == 0) { + return Status::OK(); + } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { + miscompare_callback_(expected_, actual_, mismatches_); + } + return InvalidArgument("%s", ErrorMessage().c_str()); + } + + // Insert the given absolute value into the absolute value bucket vector. The + // bounds of the buckets are given by kAbsValueBucketBounds. + void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { + // Adjust the bucket containing the absolute values of the 'actual' + // elements. + const float abs_value = FpAbsoluteValue(value); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + if (i == abs_value_buckets_.size() - 1 || + (abs_value >= kAbsValueBucketBounds[i] && + abs_value < kAbsValueBucketBounds[i + 1])) { + // The first value of the pair is the count of elements in the bucket, + // the second is the count of mismatches in the bucket. + abs_value_buckets_[i].first++; + if (is_mismatch) { + abs_value_buckets_[i].second++; + } + return; + } + } + } + + // Insert the given error into the given error bucket vector. + void UpdateErrorBucket( + float error, tensorflow::gtl::MutableArraySlice error_buckets) { + CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < error_buckets.size(); ++i) { + if (error >= kErrorBucketBounds[i]) { + error_buckets[i]++; + } + } + } + + // Compares the two given elements from the expected and actual literals at + // the given literal_index and keeps track of various mismatch statistics. + void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + const bool is_nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + float abs_error; + float rel_error; + if (actual == expected) { + abs_error = 0; + rel_error = 0; + } else if (is_nan_mismatch) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is used + // for sorting a std::set of the top mismatchs, and a nan value here will + // result in undefined behavior because nan's do not satisfy the strict + // weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = FpAbsoluteValue(actual - expected); + rel_error = abs_error / FpAbsoluteValue(expected); + } + const bool is_abs_mismatch = abs_error > error_.abs; + const bool is_rel_mismatch = rel_error > error_.rel; + const bool is_mismatch = + is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + + // Update the error of the relative bucket only if the *absolute* error + // bound is exceeded and vice versa. + if (is_abs_mismatch) { + num_abs_mismatches_++; + UpdateErrorBucket(rel_error, &rel_error_buckets_); + } + if (is_rel_mismatch) { + num_rel_mismatches_++; + UpdateErrorBucket(abs_error, &abs_error_buckets_); + } + + UpdateAbsValueBucket(actual, is_mismatch); + + if (!is_mismatch) { + return; + } + + num_mismatches_++; + + // Keep track of the kTopRelativeErrorCount relative error mismatches. + if (top_rel_mismatches_.size() < kTopRelativeErrorCount || + rel_error > top_rel_mismatches_.begin()->rel_error) { + Mismatch mismatch = {actual, expected, rel_error, abs_error, + linear_index}; + top_rel_mismatches_.insert(mismatch); + if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { + top_rel_mismatches_.erase(top_rel_mismatches_.begin()); + } + } + + mismatches_.data()[linear_index] = true; + } + + // Compares the two literals elementwise. + void CompareLiterals() { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual_.shape().layout(), + expected_.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected_.data(); + tensorflow::gtl::ArraySlice actual_data = + actual_.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + CompareValues(expected_data[i], actual_data[i], i); + } + return; + } + std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + CompareLiteralsSlow(0, &multi_index); + } + + // Slow path for CompareLiterals when 'actual' and 'expected' literals have + // different layouts. In this case, multidimensional indices are constructed + // and indexed for each element. + void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { + if (dimension == multi_index->size()) { + CompareValues(expected_.Get(*multi_index), + actual_.Get(*multi_index), + IndexUtil::MultidimensionalIndexToLinearIndex( + actual_.shape(), *multi_index)); + } else { + for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + CompareLiteralsSlow(dimension + 1, multi_index); + } + } + } + + // Returns an error message string with a detailed breakdown of the + // mismatches. Called after calling Run(). + string ErrorMessage() { + string out; + int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); + + auto percent_string = [](float a, float b) { + float pct = b == 0.0 ? 0.0 : 100.0 * a / b; + return Printf("%0.4f%%", pct); + }; + + Appendf(&out, + "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, + percent_string(num_mismatches_, element_count).c_str(), + ShapeUtil::HumanString(actual_.shape()).c_str(), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + if (num_nan_mismatches_ > 0) { + StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); + } + Appendf(&out, "Top relative error mismatches:\n"); + for (auto it = top_rel_mismatches_.rbegin(); + it != top_rel_mismatches_.rend(); ++it) { + StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + } + + if (!detailed_message_) { + return out; + } + + StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); + CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + const int64 bucket_size = abs_value_buckets_[i].first; + const int64 bucket_mismatches = abs_value_buckets_[i].second; + string mismatch_str = bucket_mismatches > 0 + ? Printf(", mismatches %lld", bucket_mismatches) + : ""; + Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count).c_str(), + mismatch_str.c_str()); + } + + auto print_accum_buckets = [&](const string& header, int64 total, + tensorflow::gtl::ArraySlice buckets) { + StrAppend(&out, header, ":\n"); + Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total).c_str()); + CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < kErrorBucketBounds.size(); ++i) { + Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total).c_str()); + } + }; + Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count).c_str()); + print_accum_buckets( + "Relative error breakdown of elements exceeding abs error bound", + num_abs_mismatches_, rel_error_buckets_); + Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count).c_str()); + print_accum_buckets( + "Absolute error breakdown of elements exceeding rel error bound", + num_rel_mismatches_, abs_error_buckets_); + return out; + } + + // 'actual' and 'expected' literals being compared. + LiteralSlice expected_; + LiteralSlice actual_; + + // The error bounds of the comparison. + ErrorSpec error_; + + // Whether to include detailed breakdown of mismatches in the error message. + bool detailed_message_; + + // Callback to invoke on miscompare. + MiscompareCallback miscompare_callback_; + + // Number of element element mismatches encountered so far. + int64 num_mismatches_ = 0; + + // Number of elements with a nan mismatch. + int64 num_nan_mismatches_ = 0; + + // Number of elements which exceed the absolute/relative error bound. + int64 num_abs_mismatches_ = 0; + int64 num_rel_mismatches_ = 0; + + // A Literal containing which elements did not match in the expected and + // actual literals. mismatches_ contains PREDs and is of the same sizes as + // the comparison literals. + Literal mismatches_; + + // The number of mismatches to report in the output, sorted by relative error + // magnitude. + static constexpr int64 kTopRelativeErrorCount = 5; + + // The set of mismatches with the largest relative error. The size of this set + // is bounded by kTopRelativeErrorCount. + std::multiset top_rel_mismatches_; + + // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the + // bounds of these buckets. abs_value_buckets_ contains a pair for each + // bucket: the element count and failure count. + static constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + std::vector> abs_value_buckets_; + + // Buckets for relative and absolute errors. The relative error buckets only + // contains those elements which exceed the *absolute* error bound, and vice + // versa. This makes it easy to see the effect of adjusting the relative (or + // absolute) error bound on the success of the comparison. kErrorBucketBounds + // are the lower bounds of the buckets in both vectors. The error buckets are + // a cumulative distribution so an error value may appear in more than one + // bucket. For example an error value of 0.003 may appear in the buckets + // bounded by 0.01, 0.1, and 1.0. + static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, + 0.01, 0.1, 1}; + std::vector abs_error_buckets_; + std::vector rel_error_buckets_; +}; + +template +constexpr std::array NearComparator::kAbsValueBucketBounds; +template +constexpr std::array NearComparator::kErrorBucketBounds; + +// Helper function for comparing two literals for nearness. Handles tuple-shapes +// via recursion. shape_index is the ShapeIndex of expected (or actual) +// currently being compared. +Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback, + const ShapeIndex& shape_index) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + + if (ShapeUtil::IsTuple(expected.shape())) { + Status return_status; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); + ShapeIndex element_index = shape_index; + element_index.push_back(i); + Status res = + NearHelper(expected_element, actual_element, error, detailed_message, + miscompare_callback, element_index); + if (!res.ok()) { + string err_message = Printf("\nArray at shape index %s%s", + element_index.ToString().c_str(), + res.error_message().c_str()); + if (return_status.ok()) { + return_status = res; + } else { + return_status = AppendStatus(return_status, res.error_message()); + } + } + } + if (!return_status.ok() && shape_index.empty()) { + // Emit a top-level error message containing the top-level shape in case + // of mismatch. + int64 total_elements = RecursiveElementCount(actual.shape()); + return_status = InvalidArgument( + "\nMismatches in shape %s (%lld elements):\n%s", + ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, + return_status.error_message().c_str()); + } + return return_status; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + switch (expected.shape().element_type()) { + case BF16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F32: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case C64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". Must be floating-point type."; + } + } + + // Non-floating point literal. + return literal_comparison::Equal(expected, actual); +} + +} // namespace + +Status EqualShapes(const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return InvalidArgument("tupleness-mismatch! want: %s got %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (ShapeUtil::IsTuple(expected)) { + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return InvalidArgument( + "want tuple element count: %lld got tuple element count: %lld", + ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + } + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + Status result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result.ok()) { + return AppendStatus(result, StrCat("mismatch in tuple index", i)); + } + } + } else { + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return InvalidArgument("want rank of %s got rank of %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (expected.element_type() != actual.element_type()) { + return InvalidArgument( + "mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()).c_str(), + PrimitiveType_Name(actual.element_type()).c_str()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return InvalidArgument("want dimensions_size %d got dimensions_size %d", + expected.dimensions_size(), + actual.dimensions_size()); + } + for (int i = 0; i < expected.dimensions_size(); ++i) { + if (expected.dimensions(i) != actual.dimensions(i)) { + return InvalidArgument( + "mismatch in dimension #%d expected: %s actual: %s", i, + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + } + } + return Status::OK(); +} + +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, actual.ToString()); + + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update( + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); + } + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (result.ok()) { + return Status::OK(); + } + + return AppendStatus(result, + tensorflow::strings::Printf("expected: %s\nactual: %s", + expected.ToString().c_str(), + actual.ToString().c_str())); +} + +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback) { + return NearHelper(expected, actual, error, detailed_message, + miscompare_callback, + /*shape_index=*/{}); +} + +string ToStringTruncated(const LiteralSlice& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; +} + +} // namespace literal_comparison +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h new file mode 100644 index 0000000000000000000000000000000000000000..00a13e361932e74a9a1e614d5c851d3851208852 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library for comparing literals without taking a dependency on testing +// libraries. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace literal_comparison { + +// Returns ok if the given shapes have the same rank, dimension sizes, and +// primitive types. +Status EqualShapes(const Shape& expected, const Shape& actual); + +// Returns ok if the expected and actual literals are (bitwise) equal for all +// elements in the literal. Also, asserts that the rank, dimensions sizes, and +// primitive type are equal. +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); + +using MiscompareCallback = + std::function; + +// Inspects whether the expected and actual literals are within the given error +// bound for all elements. Also, inspects whether the rank, dimensions sizes, +// and dimension bounds are equivalent. +// +// Tuples are matched recursively. +// +// When comparing tensors of non-floating-point type, this inspects for exact +// equality, ignoring the ErrorSpec. +// +// If the shape of the literals is neither a complex/floating-point tensor nor a +// tuple which contains a complex/floating-point tensor, Near() is equivalent to +// Equal(). We don't raise an error in this case, because we want to allow +// callers to call Near() even if they have no preconceptions about the shapes +// being compared. +// +// If detailed_message is true, then the error message in the assertion result +// will contain a more detailed breakdown of mismatches. +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback); + +// Calling ToString on a literal with over 100 million elements takes around +// 3 minutes. The utility of printing a literal with >1000 elements is +// questionable, especially when writing the Literal proto to disk is orders +// of magnitude faster. +string ToStringTruncated(const LiteralSlice& literal); + +} // namespace literal_comparison +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index b3b5e34ba220c7e9bf1cefef4b27baa6faee2c20..4c560767dc603bf805f365d594810f4df7e90ed3 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(LiteralSlice literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -95,99 +136,89 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -202,9 +233,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType(bf16_literal); +} + +/* static */ std::unique_ptr Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -264,7 +305,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_index, tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -293,22 +334,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -351,7 +391,9 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -388,7 +430,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -401,36 +443,32 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(dest_subshape).c_str(), ShapeUtil::HumanString(src_subshape).c_str()); } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); } Status Literal::MoveFrom(Literal&& src_literal, @@ -444,37 +482,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -743,7 +776,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); @@ -755,7 +788,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -774,7 +807,7 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -788,7 +821,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -802,7 +836,79 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +/* static */ std::unique_ptr Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U8: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -833,15 +939,14 @@ std::unique_ptr Literal::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -909,20 +1014,20 @@ std::unique_ptr Literal::Slice( } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -962,8 +1067,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1017,7 +1122,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1040,6 +1145,27 @@ StatusOr Literal::GetIntegralAsS64( } } +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsTuple(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); @@ -1070,7 +1196,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1082,10 +1208,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { case PRED: return std::move(*Literal::CreateR0(GetFirstElement())); // 8 bit types. @@ -1121,11 +1247,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0(GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1176,7 +1302,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1187,9 +1313,11 @@ void Literal::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1348,13 +1476,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1362,7 +1491,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1372,6 +1501,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1387,7 +1529,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1403,7 +1545,7 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1419,7 +1561,8 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1428,7 +1571,7 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); }; @@ -1443,12 +1586,12 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1466,7 +1609,7 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1486,7 +1629,7 @@ std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1521,7 +1664,8 @@ StatusOr> ConvertIfDestTypeMatches( } StatusOr> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1555,12 +1699,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> Literal::Convert( +StatusOr> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> Literal::BitcastConvert( +StatusOr> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1575,7 +1719,7 @@ StatusOr> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> Literal::ConvertToShape( +StatusOr> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1590,7 +1734,7 @@ StatusOr> Literal::ConvertToShape( } std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1602,8 +1746,8 @@ StatusOr> Literal::ConvertToShape( } template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1617,7 +1761,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1645,28 +1789,28 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { @@ -1684,11 +1828,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1741,41 +1885,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; + return true; + }); } -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1785,93 +1929,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1913,7 +2057,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1969,12 +2113,12 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1995,7 +2139,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2062,21 +2206,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2098,33 +2240,40 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2132,11 +2281,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2144,72 +2293,55 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} - -size_t Literal::Hash() const { - using tensorflow::Hash64; - using tensorflow::Hash64Combine; - - size_t hash_value = ShapeUtil::Hash(shape()); +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); - ShapeUtil::ForEachSubshape( - shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { - return; - } + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); - CHECK(LayoutUtil::IsDense(subshape.layout())); - hash_value = Hash64Combine( - hash_value, Hash64(static_cast(untyped_data(index)), - size_bytes(index))); - }); - - return hash_value; -} - -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + + piece->emplace_back(std::move(child_piece)); } } -LiteralView::~LiteralView() {} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsArray(shape_)); + CHECK_NE(src_buf_ptr, nullptr); + CHECK(LayoutUtil::HasLayout(shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(&shape_); } -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. - shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsNestedTuple(shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(&shape_); + BuildPieceSubtree(shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_.tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } - owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index c6bd03bf21ac8dc88e96856cffe02c758e7b996d..609dc7a3aca646a5bb787487de101ac115df8ea5 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,14 +51,497 @@ limitations under the License. namespace xla { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; +}; + // Class representing literal values in XLA. // -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { public: Literal() : Literal(ShapeUtil::MakeNil()) {} @@ -80,46 +562,156 @@ class Literal { Literal(const Shape& shape, bool allocate_arrays); Literal& operator=(Literal&& other); - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - LiteralProto ToProto() const; + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); - // Return the shape of the literal. - const Shape& shape() const { return shape_; } + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + // Fills this literal with the given value. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + void PopulateWithValue(NativeT value); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Factory methods below. + // - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); - int64 size_bytes(const ShapeIndex& shape_index = {}) const; + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -167,10 +759,6 @@ class Literal { values, const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. // The shape is initialized from the given dimensions. The minor dimension of // the indices array must equal the rank of the shape (i.e. size of the @@ -210,171 +798,16 @@ class Literal { tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort = true); - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-laid-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Converts this literal to another primitive type using - // static_cast<>. Returns an error if the conversion is not possible. This - // literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -423,84 +856,11 @@ class Literal { int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. - template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z); - - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template @@ -511,6 +871,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -542,135 +905,104 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32( + const LiteralSlice& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16( + const LiteralSlice& f32_literal); + + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal); + + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // End of factory methods. - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; + Piece& root_piece() const override { return *root_piece_; }; - protected: // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -698,162 +1030,69 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } - - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). - void DeallocateBuffers(); + Piece* root_piece_ = nullptr; // Implementation details shared between Populate() and PopulateParallel() template Status PopulateInternal(const FnType& generator, bool parallel); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; - - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); + LiteralSlice() : LiteralBase() {} + + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); + private: + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; - virtual ~LiteralView(); +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); - // Helper for the copy constructor and copy assignment operator. - void CopyFrom(const LiteralView& other); + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. + const Shape shape_; }; template -tensorflow::gtl::ArraySlice Literal::Piece::data() const { +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -866,7 +1105,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { } template -tensorflow::gtl::MutableArraySlice Literal::Piece::data() { +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -879,7 +1118,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { } template -NativeT Literal::Piece::Get( +NativeT LiteralBase::Piece::Get( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -887,15 +1126,15 @@ NativeT Literal::Piece::Get( } template -void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } template -tensorflow::gtl::ArraySlice Literal::data( +tensorflow::gtl::ArraySlice LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } @@ -907,13 +1146,13 @@ tensorflow::gtl::MutableArraySlice Literal::data( } template -inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT Literal::Get( +inline NativeT LiteralBase::Get( tensorflow::gtl::ArraySlice multi_index) const { return root_piece().Get(multi_index); } @@ -1160,13 +1399,13 @@ template } template -NativeT Literal::GetFirstElement() const { +NativeT LiteralBase::GetFirstElement() const { return data().at(0); } template -NativeT Literal::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1199,7 +1438,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1375,7 +1614,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { @@ -1410,6 +1649,38 @@ std::unique_ptr Literal::Replicate(int64 times) const { return literal; } +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + auto literal = MakeUnique(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61046784e05623cd3117c24ecc6d6c474739bbd5..77f979a0d701f09162e112b69f6128008872aa18 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -974,7 +975,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -985,7 +986,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1065,7 +1066,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1107,7 +1108,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1373,36 +1374,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1418,19 +1419,57 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { + std::vector int64_values = {1, 2, 3}; + const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); + + BorrowingLiteral literal(reinterpret_cast(int64_values.data()), + literal_shape); + + EXPECT_EQ(literal.Get({0}), 1); + EXPECT_EQ(literal.Get({1}), 2); + EXPECT_EQ(literal.Get({2}), 3); +} + +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { + std::vector one_two_three = {1, 2, 3}; + const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); + + std::vector hundred = {100}; + const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1}); + + std::vector src_buf_ptrs; + src_buf_ptrs.emplace_back( + reinterpret_cast(one_two_three.data())); + src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + auto literal_tuple = BorrowingLiteral( + src_buf_ptrs, + ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{0}), + 1); + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{1}), + 100); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{1}, /*shape_index=*/{0}), + 2); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{2}, /*shape_index=*/{0}), + 3); +} + TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1533,11 +1572,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 8db8c6f3de84a6c46625eadbb6b0f83d2262e5f7..3c74e070da529b7f1431e01fbaf31932f582db44 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault( // Inserts the key-value pair into the collection. Dies if key was already // present. -template -void InsertOrDie(Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& data) { - auto p = collection->insert(std::make_pair(key, data)); +template +void InsertOrDie(Collection* const collection, Key&& key, Value&& value) { + auto p = collection->insert( + std::make_pair(std::forward(key), std::forward(value))); CHECK(p.second) << "duplicate key: " << key; } @@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) { } // Inserts `value` into `set`. Dies if it was already present. -template -void InsertOrDie(Set* const set, const typename Set::value_type& value) { - CHECK(set->insert(value).second) << "duplicate value: " << value; +template +void InsertOrDie(Set* const set, Value&& value) { + CHECK(set->insert(std::forward(value)).second) + << "duplicate value: " << value; } } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ecb87bd8893276fbb9ecffaa0f8a3233d2e0043f..932cce943f7c046a85984e6e5ed6b59dae371473 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,9 +49,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 044458164ff89c554262e1bdbbd4dbed120a2ff4..cb4dc1782b680fca1485e883343fbb262b86b1d1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { @@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -271,7 +272,7 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -281,8 +282,12 @@ StatusOr LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -311,222 +318,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -534,23 +555,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5ec097846a59fdc7dd51f04f011225495643930d..a06b85b4ea28c4f386598901138930eaaed12079 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -97,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -135,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -338,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b8cce5a5f7105ef59d6f02ac7c1995bb81df4d58..04c56bbba95fbf3248df6c49700ff563c8b253c0 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -301,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -934,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index dc6f5fe5fcc067c99ced01812f9f2388a00766d0..68648a3a176363de69a56ecb8070f82862874e94 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -340,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 9656cb1c31c39dbe54293700c2765d0723255657..64f0aae0f9790f0199ac6cb931a5c9f6dc356f4c 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b871d7e246dd43811c7e8c08378d53989..1d5b75d1bee2dcee3e448d0bcb72103b539efac6 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -335,20 +335,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -535,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -569,7 +553,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +575,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +586,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -649,7 +632,7 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +641,13 @@ class ComputationBuilder(object): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. + """ + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1159,14 +1081,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. """ dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 28d6a8c3fe85fa4179bf2f41c82ad4eb93a045fe..2698ba7d79e246530b6b486d3e3bc8bf101c891e 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -330,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. + // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 10997c0719dfb80efc7b855c7888500caeb1591b..313f11a9a957155eb277dc02ba5d2565c87e0235 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, - ErrorSpec(0.0001)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + ErrorSpec(0.0001))); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index ffb72fc73c5bc1ad6e648fb3d772eb5749700dc0..5f4dc6bd08f18b50e60b173432d3d305759bccea 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -27,8 +27,8 @@ namespace xla { return std::move(grpc_service); } -::grpc::Status DelegateRPC(std::function op) { - tensorflow::Status s = op(); +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); return tensorflow::ToGrpcStatus(s); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index e1f2b0abe39b10dd82b700941748bc4f4e8cb2f8..620ac6cec4f76d938e57e87849066df59514938a 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -20,53 +20,49 @@ namespace xla { GRPCStub::~GRPCStub() = default; -tensorflow::Status MakeRPC( +Status MakeRPC( const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { ::grpc::ClientContext context; ::grpc::Status s = rpc_method(&context); return tensorflow::FromGrpcStatus(s); } -tensorflow::Status GRPCStub::TransferToClient( - const TransferToClientRequest* request, - TransferToClientResponse* response) { +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToClient(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToServer( - const TransferToServerRequest* request, - TransferToServerResponse* response) { +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToServer(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToInfeed( - const TransferToInfeedRequest* request, - TransferToInfeedResponse* response) { +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToInfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferFromOutfeed( - const TransferFromOutfeedRequest* request, - TransferFromOutfeedResponse* response) { +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferFromOutfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, - ResetDeviceResponse* response) { +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ResetDevice(context, *request, response); }); } -tensorflow::Status GRPCStub::LoadComputationSnapshot( +Status GRPCStub::LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -74,28 +70,28 @@ tensorflow::Status GRPCStub::LoadComputationSnapshot( }); } -tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Execute(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteGraph(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteParallel( - const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { +Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, + ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteParallel(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraphParallel( +Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -103,38 +99,35 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel( }); } -tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { +Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, + ExecuteAsyncResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteAsync(context, *request, response); }); } -tensorflow::Status GRPCStub::WaitForExecution( - const WaitForExecutionRequest* request, - WaitForExecutionResponse* response) { +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->WaitForExecution(context, *request, response); }); } -tensorflow::Status GRPCStub::DeconstructTuple( - const DeconstructTupleRequest* request, - DeconstructTupleResponse* response) { +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->DeconstructTuple(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationStats( - const ComputationStatsRequest* request, - ComputationStatsResponse* response) { +Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, + ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationStats(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationGraphStats( +Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -142,81 +135,77 @@ tensorflow::Status GRPCStub::GetComputationGraphStats( }); } -tensorflow::Status GRPCStub::GetComputationShape( - const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { +Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, + GetComputationShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, - GetShapeResponse* response) { +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetDeviceHandles( - const GetDeviceHandlesRequest* request, - GetDeviceHandlesResponse* response) { +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetDeviceHandles(context, *request, response); }); } -tensorflow::Status GRPCStub::CreateChannelHandle( - const CreateChannelHandleRequest* request, - CreateChannelHandleResponse* response) { +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateChannelHandle(context, *request, response); }); } // Methods used by ComputationBuilder. -tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { +Status GRPCStub::Computation(const ComputationRequest* request, + ComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Computation(context, *request, response); }); } -tensorflow::Status GRPCStub::Op(const OpRequest* request, - OpResponse* response) { +Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateOp(context, *request, response); }); } -tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { +Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, + GetLocalShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetLocalShape(context, *request, response); }); } -tensorflow::Status GRPCStub::SetReturnValue( - const SetReturnValueRequest* request, SetReturnValueResponse* responses) { +Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, + SetReturnValueResponse* responses) { return MakeRPC([this, request, responses](::grpc::ClientContext* context) { return grpc_stub_->SetReturnValue(context, *request, responses); }); } -tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { +Status GRPCStub::IsConstant(const IsConstantRequest* request, + IsConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->IsConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstant( - const ComputeConstantRequest* request, ComputeConstantResponse* response) { +Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, + ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ComputeConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstantGraph( +Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -225,17 +214,16 @@ tensorflow::Status GRPCStub::ComputeConstantGraph( } // Methods used by Computation. -tensorflow::Status GRPCStub::SnapshotComputation( - const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { +Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, + SnapshotComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->SnapshotComputation(context, *request, response); }); } // Methods used by GlobalData. -tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, - UnregisterResponse* response) { +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Unregister(context, *request, response); }); diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index fd9810d4f1a5e084b73e83007ea7f9f8b0462c72..5906d45769b5749b0c590dbc0e1972077dc3e7ba 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -28,105 +28,90 @@ class GRPCStub : public ServiceInterface { explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} ~GRPCStub() override; - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) override; - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* request, - ExecuteParallelResponse* response) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* request, - ComputationStatsResponse* response) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Methods used by ComputationBuilder. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Methods used by Computation. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) override; // Methods used by GlobalData. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; grpc::XlaService::Stub* service() { return grpc_stub_; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 9c362d8cad4642e534430becd9374351d51bf297..d1722644c72646538dab77899b79d25056f2f2bf 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -12,6 +12,7 @@ package_group( ], ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") @@ -26,6 +27,7 @@ xla_proto_library( xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], + visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) @@ -200,7 +202,22 @@ tf_cc_test( cc_library( name = "hlo_evaluator", - srcs = ["hlo_evaluator.cc"], + srcs = [ + "hlo_evaluator.cc", + "hlo_evaluator_typed_visitor.h", + "hlo_evaluator_typed_visitor_bfloat16.cc", + "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex64.cc", + "hlo_evaluator_typed_visitor_double.cc", + "hlo_evaluator_typed_visitor_float.cc", + "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int32.cc", + "hlo_evaluator_typed_visitor_int64.cc", + "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint32.cc", + "hlo_evaluator_typed_visitor_uint64.cc", + "hlo_evaluator_typed_visitor_uint8.cc", + ], hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", @@ -370,6 +387,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -743,6 +761,23 @@ cc_library( ], ) +tf_cc_test( + name = "shaped_buffer_test", + srcs = ["shaped_buffer_test.cc"], + deps = [ + ":cpu_plugin", + ":device_memory_allocator", + ":platform_util", + ":shaped_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:ptr_util", + "//tensorflow/core:test", + ], +) + cc_library( name = "executable", srcs = ["executable.cc"], @@ -837,7 +872,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -918,33 +952,6 @@ tf_cc_test( ], ) -cc_library( - name = "liveness_util", - srcs = ["liveness_util.cc"], - hdrs = ["liveness_util.h"], - deps = [ - ":hlo", - ":hlo_dataflow_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - ], -) - -tf_cc_test( - name = "liveness_util_test", - srcs = ["liveness_util_test.cc"], - deps = [ - ":hlo", - ":liveness_util", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "buffer_liveness", srcs = [ @@ -956,7 +963,6 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -993,6 +999,7 @@ cc_library( ], deps = [ ":buffer_liveness", + ":buffer_value_containers", ":heap_simulator", ":hlo", ":hlo_proto", @@ -1048,7 +1055,6 @@ cc_library( ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1081,11 +1087,11 @@ cc_library( srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], deps = [ + ":buffer_value", + ":buffer_value_containers", ":hlo", ":hlo_ordering", ":hlo_proto", - ":liveness_util", - ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1101,7 +1107,7 @@ tf_cc_test( ":heap_simulator", ":hlo", ":hlo_ordering", - ":logical_buffer", + ":hlo_value", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -1251,13 +1257,11 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -1342,6 +1346,42 @@ tf_cc_test( ], ) +cc_library( + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "gather_expander_test", srcs = ["gather_expander_test.cc"], @@ -1704,6 +1744,8 @@ tf_cc_test( ":hlo_execution_profile", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", ], ) @@ -1768,6 +1810,17 @@ cc_library( ], ) +cc_library( + name = "buffer_value_containers", + hdrs = ["buffer_value_containers.h"], + deps = [ + ":buffer_value", + ":logical_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], @@ -1842,6 +1895,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1978,10 +2069,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2005,7 +2098,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_ordering", ":hlo_pass", - ":liveness_util", ":logical_buffer", ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", @@ -2052,6 +2144,24 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = ["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], @@ -2094,7 +2204,6 @@ cc_library( ":hlo_dce", ":hlo_ordering", ":hlo_scheduling", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2142,6 +2251,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -2233,6 +2363,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2299,8 +2430,14 @@ tf_cc_test( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], - hdrs = ["device_memory_allocator.h"], + srcs = [ + "device_memory_allocator.cc", + "owning_device_memory.cc", + ], + hdrs = [ + "device_memory_allocator.h", + "owning_device_memory.h", + ], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2335,6 +2472,24 @@ cc_library( ], ) +xla_test( + name = "elemental_ir_emitter_test", + srcs = ["elemental_ir_emitter_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], @@ -2458,6 +2613,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2467,6 +2623,7 @@ tf_cc_test( srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", @@ -2478,6 +2635,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2510,7 +2668,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2620,7 +2777,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", - "@com_google_absl//absl/memory", ], ) @@ -2762,3 +2918,30 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) + +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 8e785de68cb1fbe4ce9fd58a661bdc208725483b..f732ed8f398c4699bd5247dc7fa1e9677340dcae 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -92,26 +92,6 @@ bool ReshapeIsBitcast( valid_bitcast_callback(operand->shape(), reshape->shape()); } -// Adds a scalar computation to the module to enable optimizations with dot -// converting into reduction. -HloComputation* CreateScalarBinaryComputation(HloModule* module, - PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - HloComputation* scalar_computation = - module->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_computation; -} - -} // namespace - // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -220,8 +200,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloComputation* AddReduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, {dim}, AddReduce_computation)); @@ -291,6 +270,26 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + StatusOr OptimizeDotOfGather(HloInstruction* dot); + + HloComputation* GetOrCreateScalarAddComputation() { + if (scalar_add_computation_) { + return scalar_add_computation_; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation_ = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation_; + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -309,8 +308,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; + + // Cached computation for adding two scalar F32. + HloComputation* scalar_add_computation_ = nullptr; }; +} // namespace + bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, @@ -499,13 +503,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { @@ -912,6 +916,134 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.rhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0 || + dnums.rhs_batch_dimensions_size() != 0 || + dot->shape().dimensions_size() != 2) { // dot output 2D + VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; + return nullptr; + } + + // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). + // Currently a Gather is a DynamicSlice. + auto is_dynamic_slice_constant_combination = + [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { + // First operand is a DynamicSlice(Constant). + if (a->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + auto* dynamic_slice_op = a->operand(0); + if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { + return false; + } + // Second operand is a Constant. + if (b->opcode() != HloOpcode::kConstant) { + return false; + } + // The DynamicSlice output is a vector. + const Shape& dynamic_slice_shape = a->shape(); + if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { + return false; + } + // Constant size is the same before and after slice in the contracting + // dimension, otherwise we either must precompute for all possible slice + // indices or dot is invalid. + const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); + if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != + dynamic_slice_shape.dimensions(a_contracting_dimension)) { + return false; + } + return true; + }; + + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + + if (!is_dynamic_slice_constant_combination( + lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && + !is_dynamic_slice_constant_combination( + rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { + VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " + "dot(ctB, DS(ctA)), where the two constants have equal " + "contracting dimensions."; + return nullptr; + } + + // LHS is DynamicSlice: + // input: dot(DS(ctA), ctB)) + // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. + // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. + + // RHS is DynamicSlice: + // input: dot(ctA, DS(ctB)) + // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). + // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. + + bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + + // ctA: + HloInstruction* left_operand = + lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; + // ctB: + HloInstruction* right_operand = + lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); + // Build ctA x ctB. + const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); + const int n = + right_operand->shape().dimensions(1 - rhs_contracting_dimension); + auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( + memoized_shape, left_operand, right_operand, dnums)); + // Get pair {start, 0} or {0, start}. + HloInstruction* original_start_indices = + lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); + // Position of start: + int index_of_non_zero_start = lhs_is_dynamic_slice + ? 1 - lhs_contracting_dimension + : 1 - rhs_contracting_dimension; + // Position of zero: + int index_of_zero_start = 1 - index_of_non_zero_start; + + // Slice out start and 0 components and reorder if necessary. + auto indices_type = original_start_indices->shape().element_type(); + Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + HloInstruction* non_zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_non_zero_start}, + {index_of_non_zero_start + 1}, {1})); + HloInstruction* zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_zero_start}, + {index_of_zero_start + 1}, {1})); + HloInstruction* new_start_indices = + lhs_is_dynamic_slice + ? computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {non_zero_start, zero_start}, 0)) + : computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {zero_start, non_zero_start}, 0)); + + // Build DynamicSlice(ctA x ctB). + const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; + const int new_slice_n = lhs_is_dynamic_slice ? n : 1; + auto* memoized_lookup = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + dot->shape(), memoized_inst, new_start_indices, + {new_slice_m, new_slice_n})); + + return memoized_lookup; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); @@ -941,6 +1073,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_concat_optimized); } + // Simplify dot(ConstA, Gather(Index, ConstB)) to: + // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately + // batched version of dot. + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); + if (dot_of_gather_optimized) { + VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " + "gather(i, dot*(constA, constB))"; + return ReplaceInstruction(dot, dot_of_gather_optimized); + } + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index d0c99bf818cd54b897ae9da6f9c46862254d64e5..4e082877c776c35bab499c805fef7632765a3ee1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2963,5 +2963,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); + +struct DotOfGatherTestSpec { + int64 m; + int64 k; + int64 n; + int s; // start index for dynamic slice on the non-contracting dimension + int64 lcd; // left contracting dimension + int64 rcd; // right contracting dimension + bool neg; // is negative testcase +}; + +class DotOfGatherSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface {}; + +// input: dot(DS(ctA), ctB)) +// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. +// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. +TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.m); + + // For negative tests, increase k of the dynamic slice argument to prevent the + // optimization (constants ctA, ctB must have equal contracting dimensions). + int64 k_increase = spec.neg ? 5 : 0; + int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + int32 start_row = (spec.lcd == 0) ? 0 : spec.s; + int32 start_col = (spec.lcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + + int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = 1; + int64 dot_col_size = spec.n; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +// input: dot(ctA, DS(ctB)) +// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}). +// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. +TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.n); + + int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + // For negative tests increase k of the dynamic slice argument to prevent the + // optimization + int64 k_increase = spec.neg ? 5 : 0; + int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + int32 start_row = (spec.rcd == 0) ? 0 : spec.s; + int32 start_col = (spec.rcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = spec.m; + int64 dot_col_size = 1; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +std::vector DotOfGatherPositiveNegativeTests() { + std::vector positives = { + // "Classical dot", i.e. matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to + // dot(ct, ct) before DotOfGather optimization kicks in. + // Contract on rows: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + // Reverse matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + // Contract on columns: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + }; + std::vector all; + for (int i = 0; i < positives.size(); i++) { + DotOfGatherTestSpec positive_test = positives[i]; + all.push_back(positive_test); + DotOfGatherTestSpec negative_test = positive_test; + negative_test.neg = true; + all.push_back(negative_test); + } + return all; +} + +INSTANTIATE_TEST_CASE_P( + DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, + ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index cf1231bcce4d004284b71a49063e3e470a9eb93f..95b4cb6d2e694063b648b264bd2454ae0a5469ff 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -101,7 +101,7 @@ StatusOr AllocationTracker::RegisterInternal( return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; @@ -130,7 +130,7 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { for (auto& shaped_buffer : it->second) { shaped_buffer.reset(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( @@ -220,8 +220,10 @@ void AllocationTracker::AddAllocationOrIncrementRefCount( AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { - allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, - /*ref_count=*/1}; + allocation_map[device_memory.opaque()] = { + OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), + /*ref_count=*/1}; } else { it->second.ref_count++; } @@ -235,13 +237,12 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( - device_ordinal, &device_memory)); + allocation.device_memory.Free(); allocation_map.erase(it); } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 1174fa641c06ae053bcc652416bfbc30cabc777c..a7d8927cf7e90d764ff8046df16c71922b11478e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -76,10 +76,7 @@ class AllocationTracker { // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - se::DeviceMemoryBase device_memory; - - // The device that the memory is allocated on. - int device_ordinal; + OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. @@ -126,7 +123,10 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - tensorflow::gtl::FlatMap opaque_to_allocation_map_ + // + // This is not a TF FlatMap because (currently) FlatMap (and therefore + // AllocationMap) is not movable. + std::unordered_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc new file mode 100644 index 0000000000000000000000000000000000000000..2099916509acdbc2680cc2b5bd405e96f2f7bfb8 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + +namespace xla { +StatusOr +BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot) { + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); + HloInstruction *lhs = batch_dot->mutable_operand(0), + *rhs = batch_dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + + std::vector degenerate_dims; + for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { + if (lhs_shape.dimensions(batch_dim) == 1) { + degenerate_dims.push_back(batch_dim); + } + } + + if (degenerate_dims.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.clear_lhs_batch_dimensions(); + new_dim_numbers.clear_rhs_batch_dimensions(); + + for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - + degenerate_dims.size(); + i < e; i++) { + new_dim_numbers.add_lhs_batch_dimensions(i); + new_dim_numbers.add_rhs_batch_dimensions(i); + } + + new_dim_numbers.set_lhs_contracting_dimensions( + 0, + new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size()); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, + new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); + + VLOG(2) << "Replaced " << batch_dot->ToString() << " with " + << new_dot->ToString(); + + TF_RETURN_IF_ERROR( + batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); + + return true; +} + +tensorflow::StringPiece BatchDotSimplification::name() const { + return "batch-dot-simplification"; +} + +StatusOr BatchDotSimplification::Run(HloModule* module) { + bool changed = false; + std::vector dot_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); + } + for (HloInstruction* dot_instr : dot_instrs) { + TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + changed |= elided_batch_dim_from_one; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h new file mode 100644 index 0000000000000000000000000000000000000000..c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override; + + private: + StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..38f1a5d3a645f98220ec445bb9bbdf2b9b842109 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3,7] parameter(1) + ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,3] parameter(1) + ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,20,3] parameter(1) + ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,19,3] parameter(0) + b = f32[9,1,7,1,3,20] parameter(1) + ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 38086bd7e121847be6b6b69415cfe87814e7fc24..96e02b82b97ff2fd682638f4c6297cbc2019c481 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -15,35 +15,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -80,17 +77,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_add_computation = + &scalar_add_computations_[primitive_type]; + if (*scalar_add_computation) { + return *scalar_add_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + *scalar_add_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_add_computation; } // Current HloComputation instance the BatchNormExpander is @@ -105,6 +110,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Whether rewrite has occurred. bool changed_ = false; + // Cached computations for adding two scalars. + tensorflow::gtl::FlatMap + scalar_add_computations_; + // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. // Returns the Status representing the result of the replace operation. @@ -129,6 +138,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } }; +} // namespace + bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, @@ -199,7 +210,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // X^2. auto operand_squared = add(HloInstruction::CreateBinary( @@ -500,7 +511,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 313910a861f7f4c0d1d60b738caef40e76cc4260..5e1499ee6b6ef397f95f7ed29e808d530777bd07 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_TRUE(OutputsBF16(dot->operand(1))); EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(0)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); - LiteralTestUtil::ExpectEqual( + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(1)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); } // Tests that BF16 can be propagated through nested tuples. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 94ccfedf6289b4af1accebd358671c3e2bc10ba7..c0b8bf903923a327fb1378eafb51a7d493d5e62d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -699,7 +700,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + BufferValue::ToLocationProto(*alias.instruction(), alias.index()); proto_alias->set_source_buffer_id(buffer.id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } @@ -1083,7 +1084,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1111,7 +1114,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1224,7 +1229,10 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); for (const auto& buffer_chunk : result.chunk_map) { - const LogicalBuffer& buffer = *buffer_chunk.first; + // TODO(lauj) Remove this down_cast after downstream users of + // BufferAllocation::assigned_buffers() are updated to use BufferValue. + const LogicalBuffer& buffer = + *CHECK_NOTNULL(dynamic_cast(buffer_chunk.first)); const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 15fd905e8d593994c1cd5ec77cef6db7c2dbefdb..ad0b0bf7c25d7194a06801e4ef1c9ee961f6b915 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -415,10 +415,10 @@ class BufferAssignment { // Only BufferAssigner can build or modify BufferAssignments. friend class BufferAssigner; - explicit BufferAssignment(const HloModule* module, - std::unique_ptr liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferAssignment(const HloModule* module, + std::unique_ptr liveness, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aaef9eddd64ef6b57ad5a9cf8dd6a565097..810d597e730c1823668c81598df6138655e58b55 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +43,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +70,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { @@ -105,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(), + alias.index(), user)) { continue; } if (user != b.instruction() && @@ -132,9 +131,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // the qualifications specified in CanShareOperandBufferWithUser. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && - !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), - b.instruction(), b.index(), - points_to_analysis())) { + !points_to_analysis().CanShareOperandBufferWithUser( + alias.instruction(), alias.index(), b.instruction(), b.index())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5127e383cc2ec2ab3fe1bb82ba86e4abed..cdd3cf4032ef6916086e1c2d148b575192503000 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h new file mode 100644 index 0000000000000000000000000000000000000000..305914fca828f110bf54239bddb1590172562b16 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ + +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Define various containers of BufferValues, and utilities to convert from +// containers of LogicalBuffers to containers of BufferValues. + +using BufferValueCompactPointerSet = + tensorflow::gtl::CompactPointerSet; +template +BufferValueCompactPointerSet ToBufferValueCompactPointerSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueCompactPointerSet output; + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +using BufferValueFlatSet = tensorflow::gtl::FlatSet; +template +BufferValueFlatSet ToBufferValueFlatSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueFlatSet output; + output.reserve(logical_buffer_container.size()); + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index c10609e67fcdec459baf25a95173bbf700994be9..7f2ce0e8974c01b09664235d7b9d19555b2705a3 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -75,48 +75,42 @@ class CompileOnlyService : public Service { // Override Service methods that require or imply the existence of an // execute backend. Note that this does not include TransferToClient, as // computing constants produces global data that we may wish to transfer. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec35f5b51a2670a646154ce8bb9bfc1..cb61f3da39fb8eef69fd81066d87a1da91a62935 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7885669947d89f378b1ec61e3e4e96..53c3a3f7b738687db3098acfaef1ae87860d0440 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -34,8 +34,9 @@ class ComputationLayout { public: // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. + explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cbe2ba2e50ab213133196987cf486152edc9d785..33d8338809d4e8c7c4774f062c3dda5494543ca6 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 7e6d58c7fa5ccaf3e0a6f21d43a54906a3fbe408..bfd85f257fb9550a6babb2459a7227ca9003a14f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -103,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batch_dot_simplification", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -125,6 +126,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", @@ -176,6 +178,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -295,6 +298,15 @@ cc_library( ], ) +cc_library( + name = "target_machine_features_fake", + testonly = 1, + hdrs = ["target_machine_features_fake.h"], + deps = [ + ":target_machine_features", + ], +) + cc_library( name = "ir_function", srcs = ["ir_function.cc"], @@ -336,6 +348,7 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":ir_emission_utils", ":target_machine_features", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", @@ -408,7 +421,6 @@ cc_library( "//tensorflow/core:lib", "@llvm//:analysis", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:ipo", "@llvm//:mc", "@llvm//:object", @@ -505,7 +517,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -567,6 +578,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -660,6 +687,7 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":cpu_runtime", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", @@ -672,6 +700,7 @@ tf_cc_test( srcs = ["ir_emission_utils_test.cc"], deps = [ ":ir_emission_utils", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -690,6 +719,7 @@ cc_library( deps = [ ":dot_op_emitter", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", @@ -703,6 +733,7 @@ tf_cc_test( srcs = ["cpu_layout_assignment_test.cc"], deps = [ ":cpu_layout_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -727,6 +758,7 @@ cc_library( deps = [ ":cpu_runtime", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -741,6 +773,7 @@ tf_cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -779,6 +812,7 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", + ":target_machine_features", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", @@ -791,6 +825,7 @@ tf_cc_test( deps = [ ":cpu_executable", ":parallel_task_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -913,3 +948,17 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +tf_cc_test( + name = "cpu_eigen_tensor_alignment_test", + size = "small", + srcs = ["cpu_eigen_tensor_alignment_test.cc"], + deps = [ + ":dot_op_emitter", + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2136aeb3877685373efaf5bf702a42b39a63f082..0985b9297fe487f3523826cb0978c17775549735 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -33,7 +33,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kConvolution && - !PotentiallyImplementedAsEigenConvolution(*hlo)) { + !PotentiallyImplementedAsEigenConvolution(*hlo, + target_machine_features_)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); auto input_batch_dim = dnums.input_batch_dimension(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 9b2c3d82eb673ce542cc03ec706015967dc975b6..e6fd1499edd0095395194200a5b444ad61e7e39d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,12 +33,19 @@ namespace cpu { // convolutions can run faster. class ConvCanonicalization : public HloPassInterface { public: + explicit ConvCanonicalization( + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) {} + ~ConvCanonicalization() override {} tensorflow::StringPiece name() const override { return "convolution-canonicalization"; } StatusOr Run(HloModule* module) override; + + private: + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 968f53d5c706651d2a470a853e0e9b601c0ed2df..375b017b09263c20c1b1ef8329f7e2f6a573dda4 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); @@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 91ed6e427ac7c20461cde91ef0cfdf13f4b55992..25b18eff20f901fc34343a12bfbd353ecec49cfb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" @@ -81,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" @@ -231,7 +233,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); @@ -248,8 +253,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(&target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); @@ -278,10 +284,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot) + [&target_machine_features]( + const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot, target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -296,7 +304,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->device_entry_computation_layout()); + module->mutable_device_entry_computation_layout(), + &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -316,8 +325,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); + pipeline.AddPass( + max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -470,7 +479,13 @@ StatusOr> CpuCompiler::RunHloPasses( VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + std::unique_ptr jit_target_machine = + SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -535,7 +550,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -560,10 +576,11 @@ StatusOr> CpuCompiler::RunBackend( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. + LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); + &target_machine_features, jit->external_constant_pool()); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -705,7 +722,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -745,10 +763,11 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, &hlo_profile_index_map, &hlo_profile_printer_data)); } + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - target_machine.get(), + &target_machine_features, /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 65b05f04fa8d9c72e7bfb6978f6a6384dfbcf976..e56f9f01134f84b4698c078b750b0c1fdca7748e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -148,7 +149,8 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile); + Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d12fa6bb9ad2054bdc052c9d7b3729cc28e11f6d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { + +// Test that we don't call into Eigen with tensors too small to be aligned +// reliably. + +class CpuEigenTensorAlignmentTest : public ::testing::Test {}; + +TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { + string hlo_string = R"( +HloModule DotOperation + +ENTRY DotOperation { + arg0 = f32[5,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloInstruction* dot = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE( + PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenDot( + *dot, target_machine_with_full_alignment)); +} + +TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { + string hlo_string = R"( +HloModule ConvOperation + +ENTRY ConvOperation { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloInstruction* conv = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_full_alignment)); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 32613b869078305edda97c11ac250f67de32b805..cf43b74c699ca8cbbef11a0abbaf4d69476f5d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,7 +73,7 @@ CpuExecutable::CpuExecutable( Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { + std::vector* buffers) { CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -201,60 +201,18 @@ Status CpuExecutable::ExecuteComputeFunction( return Status::OK(); } -static void LogLiveAddresses( - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - if (!VLOG_IS_ON(3)) { - return; - } - - CHECK_EQ(buffers.size(), buffers_in_result.size()); - std::vector live_out_buffers; - for (int i = 0; i < buffers.size(); ++i) { - if (buffers_in_result[i]) { - live_out_buffers.push_back(buffers[i].opaque()); - } - } - VLOG(3) << "Live addresses in output marking found " - << live_out_buffers.size() << " addresses:\n" - << tensorflow::str_util::Join( - live_out_buffers, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); -} - -static Status DeallocateTempBuffers( - DeviceMemoryAllocator* allocator, se::Stream* stream, - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - // Keep those buffers in the output of the marked live because they are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR( - allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result) { + tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/host_result_shape(), /*on_device_shape=*/host_result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. + // Move OwningDeviceMemory values which contain the array(s) of the result + // into the respective location in ScopedShapedBuffer which is returned to the + // caller. TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); @@ -273,10 +231,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( CHECK(!slice.allocation()->is_entry_computation_parameter()); const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + OwningDeviceMemory& buffer = buffers[buffer_index]; CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - (*buffers_in_result)[buffer_index] = true; + *device_memory = buffer.Forget(); return Status::OK(); })); return std::move(result_buffer); @@ -292,23 +249,21 @@ StatusOr CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - // Free all buffers not in the result. - TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - return std::move(result_buffer); + return CreateResultShapedBuffer(run_options, &buffers); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -324,30 +279,53 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - LogLiveAddresses(buffers, buffers_in_result); - - host_stream->EnqueueTask([this, run_options, arguments, buffers, - buffers_in_result, memory_allocator, stream]() { - // Failing a CHECK here is not great, but I don't see an obvious way to - // return a failed Status asynchronously. - TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, - buffers, - /*hlo_execution_profile=*/nullptr)); - TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - }); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, &buffers)); - return std::move(result_buffer); + // At this point, `unowning_buffers` contains unowning pointers to all of our + // buffers, and `buffers` contains owning pointers to the non-live-out + // buffers. Enqueue a task which keeps alive the non-live-out buffers. + // + // Logically we want this lambda to capture `buffers` by move, ultimately our + // functor needs to be wrapped in an std::function, and that requires its + // functor to be copyable. Thus we perpitrate the hack of capturing buffers + // "by shared pointer". + // + // We also need to change the types of some of the variables we capture: + // run_options needs to change from a pointer to a value type, and arguments + // needs to change from an ArraySlice into a vector. We use a struct instead + // of a lambda to make this explicit. + struct AsyncRunTask { + CpuExecutable* executable; + ServiceExecutableRunOptions run_options; + std::vector arguments; + std::vector unowning_buffers; + std::shared_ptr> buffers; + + void operator()() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. + TF_CHECK_OK(executable->ExecuteComputeFunction( + &run_options.run_options(), arguments, unowning_buffers, + /*hlo_execution_profile=*/nullptr)); + } + }; + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared>(std::move(buffers))}); + + return std::move(result); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 68ad38cba88720a04519fc2473fe6f9decbaaf93..8dd47bfb865e8a0552542f510d3365cff0d111e0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -92,7 +92,7 @@ class CpuExecutable : public Executable { // buffer is assigned for this element. Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); + std::vector* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. @@ -102,16 +102,12 @@ class CpuExecutable : public Executable { tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); - // Creates a ScopedShapedBuffer for holding the result of the computation. The - // addresses (DeviceMemoryBases) are set according to buffer assignment. - // 'buffers_in_result' should point to a vector of the same size as - // 'allocated_buffers'. An element in buffers_in_result is set to true if the - // corresponding buffer is live out of the computation (and thus contained in - // the returned ShapedBuffer). + // Creates a ScopedShapedBuffer for holding the result of the computation, + // moving buffers out of allocated_buffers and into the result as appropriate. + // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result); + tensorflow::gtl::MutableArraySlice buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index a98e85a151ffb77e6682b82164603481265283c4..46fe060817b0264d90574b45a94cf1f6e5964593 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -158,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { EXPECT_EQ(dot, computation->root_instruction()); } -TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { - HloComputation::Builder builder(TestName()); - HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); - HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[1024,256] parameter(1) + exponential = s32[1024,256] exponential(arg1) + transpose = s32[256,1024] transpose(exponential), dimensions={1,0} + ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }, TransposeFolding::NeverFoldTranspose); - EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[256,1] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[1,256] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_TransposeFusion_LHS_NonDefault) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[256,1] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); } class OpcodeFusionTest : public InstructionFusionTest { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index e8117377e61a4e21b8c45b929c518a18878fcb60..aa872d5ec9e7593b8d2f731421c17af590729529 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) { + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -126,7 +127,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + } else if (PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. @@ -139,13 +141,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - // dot is a kDot or a kTransposeDot fusion node. In the latter case, if - // it represents X @ X, it may have just one operand. - if (dot->operand_count() > 1) { - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - } + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); @@ -181,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 09adb5cb02abba5844a1740bdb50a578e1bdf8b5..3c4fe68b830d9602f009b318d4e51e9a04a27e09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" @@ -28,12 +29,16 @@ namespace cpu { class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( - const ComputationLayout& entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + ComputationLayout* entry_computation_layout, + const TargetMachineFeatures* target_machine_features) + : LayoutAssignment(entry_computation_layout), + target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index ba4c5a23d3e043fd6680c2f9abc2275696737ee7..429fc7b78608da0e9cd794ac294851b326f5be24 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -311,7 +317,12 @@ static StatusOr RunDotOutputFusion( result.addend_fusion_param = fusion_instruction->operand( fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); - cpu::CpuLayoutAssignment layout_assignment(computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index f9c51f243c47b8069500eca3c9c2929b17f04e62..e75fcb6bc9719f7453d5f0cb52d1673cef1fd3df 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,6 +22,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; +const char* const kXlaEnableExperimentalLlvmIrGemm = + "xla_enable_experimental_llvm_ir_gemm"; } // namespace @@ -54,6 +56,12 @@ tensorflow::gtl::optional LlvmIrGemvTilingFactor( return tensorflow::gtl::nullopt; } +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index be62ff3cc1af23408ca8a00f1372e7a998f160c6..106dfbbc62dfba8d3de74e0a2ae3bb247bd91d67 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,6 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 215405f6802cf1956ebec011da2fcd11b95c0c64..54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 1dce6efa5cd65e67ae73a2e2affe2d2d3c537508..aa0e96712302e806a389c6ad05a2c1b6634ef901 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -52,6 +52,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 9b39e7f5765ae5eb6a25c06eef4d74b1c00e5c91..d97802ee45d6add3c466577d7624d9ca74e2f380 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 3ecb0d236498371f48caf63249f9cd4e8777752b..6dfc666f09dfa6df740cd54bea0957e3144181bc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager { ~CpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 801c5239081d174ba6278d54009e790863f3bcb9..af69fc3da9869aa2df958ecc5c064ee37dd9ea21 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -520,18 +521,259 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } } +// This class implements a tiled matrix multiplication algorithm, intended for +// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, +// Kazushige, and Robert Van De Geijn. "High-performance implementation of the +// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): +// 4). +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class MatrixMatrixBlockPanelEmitter { + public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + // + // `m`, `k` and `n` are the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // `k_tiling_factor` is the number of elements along the reduction dimensions + // that we will attempt to process at once. + explicit MatrixMatrixBlockPanelEmitter( + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims, + int max_vectorization_width, int min_vectorization_width, + int k_tiling_factor, const TargetMachineFeatures& target_machine_features, + llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type) + : lhs_(lhs), + rhs_(rhs), + result_(result), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + min_vectorization_width_(min_vectorization_width), + k_tiling_factor_(k_tiling_factor), + target_machine_features_(target_machine_features), + ir_builder_(ir_builder), + primitive_type_(primitive_type), + ksl_(ir_builder_) { + CHECK(max_vectorization_width > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width))); + CHECK(min_vectorization_width > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width))); + CHECK_GT(k_tiling_factor, 0); + } + + void Emit(); + + private: + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. This function emits that outermost loop. + void EmitChunkedLoopOverN(); + + // This emits a loop that loops over the `k` dimension in multiples of + // `k_tiling_factor` as much as possible and then emits a remainder epilogue. + void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + + // This emits the inner reduction loop. This inner reduction loop processes + // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension + // and [`n_start`, `n_end`) in the `n` dimension. + void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, + llvm::Value* n_end, VectorSupportLibrary* vsl); + + llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Dimensions dims_; + + int64 max_vectorization_width_; + int64 min_vectorization_width_; + int64 k_tiling_factor_; + + const TargetMachineFeatures& target_machine_features_; + llvm::IRBuilder<>* ir_builder_; + PrimitiveType primitive_type_; + KernelSupportLibrary ksl_; +}; + +void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); } + +void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { + int64 current_vectorization_width = max_vectorization_width_; + int64 n_start = 0; + while (n_start != dims_.n() && + current_vectorization_width >= min_vectorization_width_) { + int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(primitive_type_, current_vectorization_width, + ir_builder_, "gebp"); + EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end)); + n_start = n_end; + } + current_vectorization_width /= 2; + } + + if (n_start != dims_.n()) { + VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp"); + ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = + ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + EmitLoopOverK(&vsl, n_i, n_i_next); + }); + } +} + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_); + if (k_end != k_start) { + EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start, + n_end, vsl); + k_start = k_end; + } + + if (k_start != dims_.k()) { + EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()), + n_start, n_end, vsl); + } +} + +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +---+---+---+ +// | a | b | c | . +// +---+---+---+ . +// | | | | . +// +---+---+---+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let k_tiling_factor be 3 and the vector width (implicitly denoted by +// `vsl`) be 4. +// +// Then we +// +// 1. broadcast the first row in LHS to 3 vectors of width 4 +// 2. elementwise multiply the RHS rows with these broadcasted vectors +// 3. elementwise add them: +// +// +---+---+---+---+ +----+----+----+----+ +// | a | a | a | a | * | p0 | p1 | p2 | p3 | + +// +---+---+---+---+ +----+----+----+----+ +// +// +---+---+---+---+ +----+----+----+----+ +// | b | b | b | b | * | q0 | q1 | q2 | q3 | + +// +---+---+---+---+ +----+----+----+----+ +// +// +---+---+---+---+ +----+----+----+----+ +// | c | c | c | c | * | r0 | r1 | r2 | r3 | +// +---+---+---+---+ +----+----+----+----+ +// +// to get: +// +// +----------------+----------------+----------------+----------------+ +// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 | +// +----------------+----------------+----------------+----------------+ +// +// which we increment into the appropriate region in the result. +void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( + int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) { + ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) { + // This outer loop iterates over all of the M dimension + llvm::Value* result_row_begin = vsl->ComputeOffsetPointer( + result_, /*offset_elements=*/m_i, /*scale=*/dims_.n()); + llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer( + lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k()); + + ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) { + // broadcasted_a is the broadcasted set of vectors denoted as , + // etc. in the diagram. + std::vector broadcasted_a; + broadcasted_a.reserve(k_tiling_factor); + for (int i = 0; i < k_tiling_factor; i++) { + broadcasted_a.push_back(vsl->LoadBroadcast( + lhs_row_begin, ir_builder_->CreateAdd(getInt64(i), k_i))); + } + + // rhs_loader will be used to load the tile off of the RHS, denoted as + // <, ...> in the diagram. + TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i, + k_tiling_factor); + ksl_.For( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + // This loop iterates over the N dimension. It loads the tile from + // RHS, does the FMA resulting in the + // in the diagram and increments + // the result. + std::vector tile = rhs_loader.LoadTile(n_i); + llvm::Value* result_accumulator = + vsl->LoadVector(result_row_begin, n_i); + for (int i = 0; i < tile.size(); i++) { + result_accumulator = + vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator); + } + vsl->StoreVector(result_accumulator, result_row_begin, n_i); + }); + }); + }); +} + } // namespace -DotOpEmitter::DotOpEmitter( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -541,22 +783,80 @@ DotOpEmitter::DotOpEmitter( hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, +/* static */ Status DotOpEmitter::EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, addend_array, - executable_run_options_value, ir_builder, - hlo_module_config, target_machine_features); + DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, + addend_array, executable_run_options_value, + ir_builder, hlo_module_config, + target_machine_features); return dot_emitter.Emit(); } +bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( + const DotOpEmitter::MatMultDims& mat_mult_dims) { + if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + return false; + } + + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + switch (primitive_type) { + default: + return false; + + case F32: + case F64: + case S32: + case S64: + break; + } + + if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && + mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { + return false; + } + + VLOG(2) << "Emitting GEBP kernel in LLVM IR"; + + llvm::Value* lhs = lhs_array_.GetBasePointer(); + llvm::Value* rhs = rhs_array_.GetBasePointer(); + llvm::Value* target = target_array_.GetBasePointer(); + int64 m = mat_mult_dims.m; + int64 k = mat_mult_dims.k; + int64 n = mat_mult_dims.n; + + if (mat_mult_dims.lhs_column_major) { + std::swap(lhs, rhs); + std::swap(m, n); + } + + int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + ir_builder_->CreateMemSet( + target, ir_builder_->getInt8(0), size_bytes, + target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + + MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k, + /*n=*/n); + MatrixMatrixBlockPanelEmitter gebp_emitter( + /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims, + /*max_vectorization_width=*/8, /*min_vectorization_width=*/4, + /*k_tiling_factor=*/8, target_machine_features_, ir_builder_, + primitive_type); + gebp_emitter.Emit(); + return true; +} + bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { return false; @@ -578,7 +878,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.m == 1) { bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major; if (rhs_effectively_row_major) { k = mat_mult_dims.k; m = mat_mult_dims.n; @@ -594,7 +894,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.n == 1) { bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major; if (lhs_effectively_column_major) { m = mat_mult_dims.m; k = mat_mult_dims.k; @@ -609,7 +909,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; + return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -690,7 +990,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -734,23 +1034,17 @@ tensorflow::Status DotOpEmitter::Emit() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_)) { + if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { return EmitCallToRuntime(); } // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = 0; - if (ShapeUtil::Rank(lhs_shape) >= 2) { - lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); - } - int64 rhs_reduction_dimension = 0; - if (ShapeUtil::Rank(rhs_shape) >= 2) { - rhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); - } + int64 lhs_reduction_dimension = + dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = + dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -874,10 +1168,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; llvm::Value* lhs_value = @@ -902,10 +1196,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -914,8 +1208,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded = - hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -986,8 +1279,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; - bool transpose_lhs = transpose_lhs_; - bool transpose_rhs = transpose_rhs_; + bool transpose_lhs = mat_mult_dims.lhs_non_canonical; + bool transpose_rhs = mat_mult_dims.rhs_non_canonical; if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); @@ -1007,7 +1300,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { @@ -1015,12 +1308,18 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, - LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; + const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + + return { + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), + /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1, + /*target_column_major=*/ + LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1060,19 +1359,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) { // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. -static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { +static bool AreValidGemmShapes( + const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. PrimitiveType output_primitive_type = output_shape.element_type(); - return (output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16) && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); + if (!(output_primitive_type == F64 || output_primitive_type == F32 || + output_primitive_type == F16)) { + return false; + } + + if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape))) { + return false; + } + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || + !is_aligned(output_shape)) { + return false; + } + + return true; } -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features) { // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -1089,28 +1408,18 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), + target_machine_features)) { + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 47e09243340840980ebe21be3a2b056985877235..d88ccea0dbc845c0d9a580a5b118c57c888fb557 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -31,7 +31,9 @@ limitations under the License. namespace xla { namespace cpu { -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot @@ -55,17 +57,16 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. - static tensorflow::Status EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + static Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); private: - DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, const llvm_ir::IrArray& target_array, + DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -75,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. - tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout @@ -111,11 +112,20 @@ class DotOpEmitter { // The number of columns on the RHS. int64 n; - // True if the LHS matrix column major. + // True if the LHS matrix is column major. bool lhs_column_major; - // True if the RHS matrix column major. + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; + + // True if the RHS matrix is column major. bool rhs_column_major; + + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; + + // True if the result matrix is column major. + bool target_column_major; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -123,6 +133,8 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; + bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. int64 GetGemvTilingFactor() const { @@ -131,9 +143,18 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + // Returns true if we should use an experimental implementation of GEMM + // (general matrix matrix multiplication) if possible. + bool EnableExperimentalLlvmIrGemm() const { + return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); + } + + // Returns true if we should call into multi-threaded Eigen routines. + bool ShouldUseMultiThreadedEigen() { + return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + } + const HloInstruction& dot_; - const bool transpose_lhs_; - const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca7fa08b478f24065275ffa69725dc51682..c56286559158758ca6db5ae097729286bde346f0 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56df4dbf16e7b57aee8a344058bb0d5883d..0677f5f0b58005079890052a426e5f48c5d09ed1 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. - void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index f209a69e3cd0f8d336d61bafd1e22be8bc88ca3f..b560b7531c0d24e6f670e61a15dce295d9fa2a49 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -24,8 +24,25 @@ limitations under the License. namespace xla { namespace cpu { +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features) { + CHECK(ShapeUtil::IsArray(shape)); + CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); + + // We don't require a layout to be set on `shape`. This only works on CPU + // because we don't pad our tensors or otherwise have complicated data tiling + // schemes. + + int64 allocation_size_bytes = + ShapeUtil::ElementsIn(shape) * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + return target_machine_features.minimum_alignment_for_allocation( + allocation_size_bytes); +} + bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution) { + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features) { // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -35,6 +52,18 @@ bool PotentiallyImplementedAsEigenConvolution( // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); const Shape& kernel_shape = convolution.operand(1)->shape(); + const Shape& output_shape = convolution.shape(); + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(input_shape) || !is_aligned(kernel_shape) || + !is_aligned(output_shape)) { + return false; + } + if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; @@ -71,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution( } } - const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && dnums.output_batch_dimension() == 0 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 34b2003916933f5ec0a15d9e219063c0a912fa40..68fbc7caaa9bfec0ecd7cc7f473c8ca8afce19db 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -17,13 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace cpu { bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution); + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features); + +// Computes the minimum alignment guaranteed for a tensor of shape `shape` on +// the target machine. +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features); // Dynamic loop bounds are specified as an array of dimension index // [start, limit) pairs of ir values (one for each partitioned outer dimension). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index 215f48c4cc1a1a6b13d98dff76e0d1f0f773f5c1..abb2471e6ae6b2f2949ab2e91235e5047ae404f8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" @@ -39,7 +40,12 @@ ENTRY Conv { HloComputation* entry_computation = module->entry_computation(); HloInstruction* conv_instr = entry_computation->root_instruction(); - EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution( + *conv_instr, target_machine_features)); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6347ee2a2a17502bb78ff3bdbba10ead35b72c40..13bd5e73db500e20b0e8c33bf921ee2457e126e5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -83,7 +83,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine_features, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), @@ -94,7 +94,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(target_machine), + target_machine_features_(*target_machine_features), external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() @@ -227,32 +227,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { } } -// Calculate the alignment of a buffer with a particular size. -int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { - // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on - // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. - if (buffer_size == 0) { - // No need to align empty buffers. - return 1; - } - - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= kMallocAlignmentThreshold - ? 2 * pointer_size - : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; -} - // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); @@ -277,7 +251,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); - return MinimumAlignmentForBufferSize(buffer_size); + return target_machine_features_.minimum_alignment_for_allocation(buffer_size); } void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, @@ -290,7 +264,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size) { - int alignment = MinimumAlignmentForBufferSize(buffer_size); + int alignment = + target_machine_features_.minimum_alignment_for_allocation(buffer_size); if (alignment > 1) { llvm_ir::SetAlignmentMetadataForLoad(load, alignment); } @@ -530,7 +505,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16})); + /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -827,13 +802,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { "Dot with multiple contracting dimensions not implemented."); } - if (dnums.lhs_contracting_dimensions(0) != - std::min(lhs->shape().dimensions_size() - 1, 1) || - dnums.rhs_contracting_dimensions(0) != 0) { - return Unimplemented( - "Dot with non-standard contracting dimensions not implemented."); - } - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -850,8 +818,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, /*addend_array=*/nullptr, + *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, target_machine_features_); } @@ -869,7 +836,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. - if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + if (PotentiallyImplementedAsEigenConvolution(*convolution, + target_machine_features_)) { const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -1035,12 +1003,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), - "convolution_sum_address", &ir_builder_, + lhs_llvm_type, "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + llvm::Value* constant_zero = + llvm::Constant::getNullValue(lhs_llvm_type); + ir_builder_.CreateStore(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); std::vector kernel_spatial(num_spatial_dims); @@ -1194,7 +1164,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -1216,16 +1192,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { - if (hlo_module_config_.replica_count() == 1) { - // When there is a single replica, a cross replica sum is the identity - // function, and the buffer assignment expects a copy (we could eliminate - // these at the HLO level as an optimization). - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on CPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on CPU."); + } + + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + + // CRS with one operand and one replica is simply the identity function. + if (crs->operand_count() == 1) { return EmitMemcpy(*crs->operand(0), *crs); } - // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented("CrossReplicaSum is not implemented on CPU."); + // CRS with multiple operands and one replica produces a (one-deep) tuple. + std::vector operand_ptrs; + for (int64 i = 0; i < crs->operand_count(); ++i) { + llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i)); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(crs, {i})); + + const Shape& operand_shape = crs->operand(i)->shape(); + CHECK(ShapeUtil::IsArray(operand_shape)) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + + // TODO(b/63762267): Be more aggressive about specifying alignment. + ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, + ShapeUtil::ByteSizeOf(operand_shape)); + } + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + return Status::OK(); } // Fills up the free variables in 'index_with_free_var' with values from @@ -2086,44 +2091,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); - if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - DCHECK(root->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - fusion->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - fusion->operand(rhs_parameter->parameter_number()); - - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64})); - - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - - Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayFor(fusion); - VLOG(2) << "HandleFusion kTransposeDot: "; - VLOG(2) << " lhs operand: " - << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); - VLOG(2) << " rhs operand: " - << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); - VLOG(2) << " target: " - << llvm_ir::DumpToString(*target_array.GetBasePointer()); - - // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *root, root->operand(0)->IsRank2Transpose(), - root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); - return Status::OK(); - } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, - assignment_)) { + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); @@ -2166,9 +2134,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { GetIrArrayFor(fusion->operand(addend_param_number))); TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 5a040760804fa5609e1d68511d4b2abe8e2ec8f9..f49cfc1dc378bb80da3ddf995363acfa2081067b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine, ExternalConstantPool* external_constant_pool); ~IrEmitter() override; @@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Calculate the alignment of a buffer allocated for a given primitive type. int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); - // Calculate the alignment of a buffer with a particular size. - int MinimumAlignmentForBufferSize(int64 buffer_size); - // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; @@ -536,7 +533,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation_; - TargetMachineFeatures target_machine_features_; + const TargetMachineFeatures& target_machine_features_; int64 external_global_constant_counter_ = 0; ExternalConstantPool* external_constant_pool_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index fb28280fade307ac1f193e7dca481bd2afa855fc..63d0f7b95c7e45913c707471dbe2dc62e05251d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel { ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module, + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. auto cost_analysis = MakeUnique(shape_size); @@ -127,7 +129,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Emit custom loops (kSelectAndScatter). // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. @@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) || + PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { @@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { ParallelTaskAssignment parallel_task_assignment(max_parallelism_, - shape_size_function_, module); + shape_size_function_, module, + &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 7140dabe516cd7ea9260456e994e8b63b68c60d6..8becc8fa23424d7454cc783eb9d853aecb5d053b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,7 +40,8 @@ class ParallelTaskAssignment { // 'module': the containing HloModule. ParallelTaskAssignment(const int64 max_parallelism, const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + HloModule* module, + const TargetMachineFeatures* target_machine_features); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -47,6 +49,7 @@ class ParallelTaskAssignment { private: std::unique_ptr cost_model_; + const TargetMachineFeatures& target_machine_features_; }; // ParallelTaskAssigner computes target parallel task counts for all HLOs @@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size, + const TargetMachineFeatures* target_machine_features) + : max_parallelism_(max_parallelism), + shape_size_function_(shape_size), + target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface { int64 max_parallelism_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 13eb75a57213b1a68a5732a4f6061efdf97fa4f4..fc2efbaf9a22b02cd729da2f367d53bc15506836 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,6 +32,19 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { // Use any value larger than 2 since we only test whether a module is // parallelized or not const int max_parallelism_ = 10; + + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; + + ParallelTaskAssignmentTest() + : target_machine_features_([](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }) {} + + StatusOr RunParallelTaskAssigner(HloModule* module) { + return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, + &target_machine_features_) + .Run(module); + } }; TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { @@ -45,9 +59,7 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -74,9 +86,7 @@ TEST_F(ParallelTaskAssignmentTest, )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -92,9 +102,7 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -108,9 +116,7 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e02475babad7160d0f43bb23de0b50e..0bf693edd0b985a4e62c16414646cc6a17db26ee 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..2613ddb12704aea7d0884c6c8c062dc028383639 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..dcd133d012cf074a4cd2f550585881388bea6156 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index ff6f0a9d4e443c2ed7d2dd6c58f4aaf28205b0cb..c4c90515ac7ec2721cb9ea48d42e3c5080e249af 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -73,23 +74,33 @@ llvm::StringRef GetHostCpuName() { } } // namespace +/*static*/ std::unique_ptr +SimpleOrcJIT::InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level) { + std::unique_ptr target_machine( + llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/GetHostCpuName(), + /*MAttrs=*/DetectMachineAttributes())); + CHECK(target_machine != nullptr); + return target_machine; +} + SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) - : target_machine_( - CHECK_NOTNULL(llvm::EngineBuilder() - .setTargetOptions(target_options) - .setOptLevel(opt_level) - .selectTarget( - /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/GetHostCpuName(), - /*MAttrs=*/DetectMachineAttributes()))), + : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -192,6 +203,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index f4260a95bc45557b6cd969f7d3fff01c8b392575..1851a3ee0bb97b4860605d7211a6ae70ac88686b 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -95,6 +95,12 @@ class SimpleOrcJIT { return &external_constant_pool_; } + // Creates an llvm::TargetMachine suitable for JITting code that will run on + // the current machine. + static std::unique_ptr InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index eeb049737dddd11ef2ce229df772baec3ac03dd8..a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { namespace cpu { -llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( +llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( const llvm::Function& function) const { auto it = target_transform_info_cache_.find(&function); if (it == target_transform_info_cache_.end()) { @@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( return &it->second; } +int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( + int64 size_bytes) const { + // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 + // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (size_bytes == 0) { + // No need to align empty buffers. + return 1; + } + + const int64 kMallocAlignmentThreshold = 512; + + int pointer_size = target_machine_->getPointerSize(0); + int buffer_alignment = + size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 703942615e552dccde7ddec8c8b90e8a486652af..8b00ae9e47eeed26ffe80707b89593b267e8dbb8 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -24,43 +24,68 @@ limitations under the License. namespace xla { namespace cpu { -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// LLVM IR code generation decisions. +// Abstract interface for classes providing information about the target we're +// compiling for. class TargetMachineFeatures { public: static constexpr int kX86AvxVectorByteSize = 32; - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} + // Input and output tensor buffers must be aligned to this many bytes if we + // want to call an Eigen backed GEMM or Convolution. + static constexpr int kEigenExpectedTensorAlignment = 16; // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. - return 128; - } + virtual int vectorization_factor_in_bytes() const = 0; // Return the size of the largest vector size in bytes. We need to pass in // "function" since llvm functions can contain annotations for specializing // them to specific micro-architectures (though currently XLA does not use // this functionality). - int vector_register_byte_size(const llvm::Function& function) const { - llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); - return tti->getRegisterBitWidth(/*Vector=*/true) / 8; - } + virtual int vector_register_byte_size( + const llvm::Function& function) const = 0; // Return the number of elements of type `type` that can fit into the largest // vector register available. We need to pass in "function" since llvm // functions can contain annotations for specializing them to specific // micro-architectures (though currently XLA does not use this functionality). + virtual int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const = 0; + + // Returns the minimum alignment for a buffer of size size_bytes. + virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; + + virtual ~TargetMachineFeatures() = default; +}; + +// Implements the TargetMachineFeatures interface using an llvm::TargetMachine. +class LLVMTargetMachineFeatures : public TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + int vectorization_factor_in_bytes() const override { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + int vector_register_num_elements(const llvm::Function& function, - PrimitiveType type) const { + PrimitiveType type) const override { return vector_register_byte_size(function) / (primitive_util::BitWidth(type) / 8); } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h new file mode 100644 index 0000000000000000000000000000000000000000..ffc6927cbe1a2b6fd1a1ca3aac9b6e047741c2af --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { +// Delegates calls to minimum_alignment_for_allocation to a user provided +// std::function, crashes on all other methods. +// +// Primarily useful for testing. +class TargetMachineFeaturesWithFakeAlignmentLogic + : public TargetMachineFeatures { + public: + explicit TargetMachineFeaturesWithFakeAlignmentLogic( + std::function fake_alignment_logic) + : fake_alignment_logic_(std::move(fake_alignment_logic)) {} + + int vectorization_factor_in_bytes() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { + return fake_alignment_logic_(size_bytes); + } + + private: + std::function fake_alignment_logic_; +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 18a915e5339623c73fee0e339fe75ee405898a36..67f776e7b5883f425b41c05342b74bebe223e17f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 6479bf76aab581ae3ec2923d98dab53720cab203..edcaec584997b17dce30b8c46fda4abc78441064 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -143,6 +143,12 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements, int64 scale) { + return ComputeOffsetPointer( + base_pointer, + ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { return ComputeOffsetPointer(base_pointer, diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 35db4fd2a22cc1615ade77a801cb28c504db09a6..e228bb56bce8febcca28ae171f6de90973d020ab 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -29,7 +29,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -StatusOr StreamExecutorMemoryAllocator::Allocate( +StatusOr StreamExecutorMemoryAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); @@ -40,22 +40,17 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, device_ordinal); } - return result; + return OwningDeviceMemory(result, device_ordinal, this); } -tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( - int device_ordinal, se::DeviceMemoryBase* mem) { - if (!mem->is_null()) { +Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + se::DeviceMemoryBase mem) { + if (!mem.is_null()) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - // We make a local copy of 'mem' so the original is not zeroed out by the - // Deallocate() call below. This gives us a better chance of - // catching double-free bugs, since Deallocate silently succeeds for null - // values. - se::DeviceMemoryBase mem_copy(*mem); - stream_executor->Deallocate(&mem_copy); + stream_executor->Deallocate(&mem); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index da45c4d45a1c56fd39b1e3e17ff131de59ceeced..d87b86caf0d3acaa5bf9a455cff2315cedb2496d 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,28 +38,29 @@ class DeviceMemoryAllocator { : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} + // Allocates memory on the device. + // + // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory + // must not be null. If size == 0, must return a null OwningDeviceMemory. + // // 'retry_on_failure': If false, and the first attempt to allocate the memory // fails, the allocation should return immediately without retrying. An // example use case is optional scratch spaces where a failure has only // performance impact. - // - // Allocate() should return a null pointer for a size-0 allocation. - // Deallocate() must be a no-op for null pointers. - virtual StatusOr Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) = 0; + virtual StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) = 0; // Two-arg version of Allocate(), which sets retry-on-failure to true. // // (We don't simply use a default argument on the virtual Allocate function // because default args on virtual functions are disallowed by the Google // style guide.) - StatusOr Allocate(int device_ordinal, uint64 size) { + StatusOr Allocate(int device_ordinal, uint64 size) { return Allocate(device_ordinal, size, /*retry_on_failure=*/true); } - virtual tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) = 0; + // Must be a nop for null pointers. + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. const se::Platform* platform() const { return platform_; } @@ -68,6 +70,7 @@ class DeviceMemoryAllocator { virtual bool AllowsAsynchronousDeallocation() const = 0; protected: + friend class OwningDeviceMemory; const se::Platform* platform_; }; @@ -79,14 +82,13 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { const se::Platform* platform, tensorflow::gtl::ArraySlice stream_executors); - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; // Pull in two-arg overload that sets retry_on_failure to true. using DeviceMemoryAllocator::Allocate; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 0528b076027603796a445d8b0e9cbcebd1b513a7..b9d7ec9c2e17e560580fcea060bf552c42fe3b3c 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -138,6 +138,9 @@ class DfsHloVisitorBase { virtual Status HandleExp(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleExpm1(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleFloor(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -150,6 +153,9 @@ class DfsHloVisitorBase { virtual Status HandleClz(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleLog1p(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCos(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index ae32d33766093cf4e610a0dc05f7d8c88cb37d31..9a8bab353ef6b1e0b05b250d35296bc3cef8bc37 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -418,8 +418,12 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value); + case HloOpcode::kExpm1: + return EmitExpm1(op->shape().element_type(), operand_value); case HloOpcode::kLog: return EmitLog(op->shape().element_type(), operand_value); + case HloOpcode::kLog1p: + return EmitLog1p(op->shape().element_type(), operand_value); case HloOpcode::kCos: return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: @@ -493,6 +497,22 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex( op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } + case HloOpcode::kLog1p: { + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto one = llvm::ConstantFP::get(llvm_ty, 1.0); + auto a_plus_one = ir_builder_->CreateFAdd(a, one); + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(a_plus_one, a_plus_one), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -523,6 +543,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } + case HloOpcode::kExpm1: { + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); + auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); + auto real_result = + ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); + auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + return EmitComposeComplex(op, real_result, imag_result); + } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) @@ -975,6 +1009,28 @@ StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + // When x is large, the naive evaluation of ln(x + 1) is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto for_large_x, + EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. + auto for_small_x = ir_builder_->CreateFMul( + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), + x); + const auto kAntilogarithmIsSmallThreshold = 1e-4; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -993,6 +1049,29 @@ StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto half = llvm::ConstantFP::get(type, 0.5); + // When the exponent is large, the naive evaluation of e^(x) - 1 is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); + auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. + // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. + auto x_squared = ir_builder_->CreateFAdd(x, x); + auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); + auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + const auto kExponentIsSmallThreshold = 1e-5; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { @@ -1468,6 +1547,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* operand_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; @@ -1476,14 +1575,8 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( llvm_ir::IrArray::Index input_index(rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1582,104 +1675,48 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const int64 rank = ShapeUtil::Rank(input_hlo->shape()); llvm_ir::IrArray::Index slice_start_index(rank); llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = ir_builder_->getTrue(); for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); llvm::Value* input_dim_size = llvm::ConstantInt::get( index[i]->getType(), input_hlo->shape().dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( index[i]->getType(), update_hlo->shape().dimensions(i)); - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; slice_limit_index[i] = ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection = ir_builder_->CreateAnd( slice_intersection, ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), index_wraps, - "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = ir_builder_->CreateAnd( - slice_intersection, slice_intersection_or, "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. - SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming(slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; + "slice_intersection"); } // Emit: @@ -1696,12 +1733,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); @@ -1784,8 +1816,13 @@ StatusOr ElementalIrEmitter::EmitElementalDot( const llvm_ir::IrArray::Index& dot_result_index) const { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); + + const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); + int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0); + int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); + + int64 contracted_dim_size = + hlo->operand(0)->shape().dimensions(lhs_contracting_dim); int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); @@ -1816,13 +1853,12 @@ StatusOr ElementalIrEmitter::EmitElementalDot( for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); } - lhs_index.push_back(inner_loop->GetIndVarValue()); + lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); - for (int64 i = 0; i < rhs_dims - 2; i++) { + for (int64 i = 0; i < rhs_dims - 1; i++) { rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); + rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); llvm::Value* current_accumulator = ir_builder_->CreateLoad(accumulator_alloca); @@ -1877,10 +1913,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 26dff0d96f1d0f00fcdf12410a3769d18a186773..d199473374ad394913413a7d3fe805f8782936f7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -105,6 +105,9 @@ class ElementalIrEmitter { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const; @@ -114,6 +117,9 @@ class ElementalIrEmitter { virtual StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b43dc0c65d9b6e7c05e06010ba2ff2eb27392295 --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ElementalIrEmitterExecutionTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text, config)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { + const string hlo_text = R"( +HloModule FusedDot + +fused_computation { + arg0 = s32[1,2,1]{2,1,0} parameter(0) + reshape.lhs = s32[2,1]{1,0} reshape(arg0) + arg1 = s32[1,2,1]{2,1,0} parameter(1) + reshape.rhs = s32[2,1]{1,0} reshape(arg1) + ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY main { + entry_arg0 = s32[1,2,1]{2,1,0} parameter(0) + entry_arg1 = s32[1,2,1]{2,1,0} parameter(1) + ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation +} +)"; + + std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {lhs.get(), rhs.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979..6794cfe297b0fb9a15eb9b7e6906d225f9597d07 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf9f16a85f7863f4d05c39c7d4c99209af1..4458152dd9a98890fc3a3e7f324245ec68821467 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index ddb687314ee8221ba9282f230db498b3a5d23499..5ee67ccb4ae147683c7b41941670c6fc413a0d09 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -89,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( } Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 0579099de40ba3e43f69e4e6474b56691064c692..3da9570ef7eebcdf618439f628fb4d5589993e4f 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -45,11 +45,11 @@ class GenericTransferManager : public TransferManager { se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7cb7f550730eeb562a6163cf5499ffaaf02d3327..4012f87f2bf69d1ab056da5d6c750441c7404980 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -291,6 +291,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", @@ -388,8 +389,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 837f05244f7a8c71483cc30eeac9e1c219e6bbd2..ab5149dcdb09290cd0c0b2233029d0988a95f036 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -37,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - const int64 num_buffers = buffer_assignment.Allocations().size(); - auto buffer_allocations = WrapUnique( - new BufferAllocations(num_buffers, device_ordinal, memory_allocator)); + const int64 num_buffers = buffer_assignment->Allocations().size(); + auto buffer_allocations = WrapUnique(new BufferAllocations( + num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { // If buffer #i's address is already registered (e.g. external arguments or @@ -62,28 +62,28 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( - device_ordinal, buffer_size)); - if (buffer_address == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s for buffer %lld.", - tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), - i); - } - if (reinterpret_cast(buffer_address.opaque()) % + OwningDeviceMemory buffer; + TF_ASSIGN_OR_RETURN( + buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); + if (reinterpret_cast(buffer.opaque()) % kCudaMallocAlignBytes != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer_address.opaque()); + kCudaMallocAlignBytes, buffer.opaque()); } + // We do manual memory management within BufferAllocations. Be sure not + // to do a TF_RETURN_IF_ERROR between this line and the + // buffer_allocations->SetBuffer(buffer_address) call below! + buffer_address = buffer.Forget(); } + buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { if (seen_temp_buffer) { @@ -103,28 +103,42 @@ StatusOr> BufferAllocations::Builder::Build( << "B)"; } } - return std::move(buffer_allocations); } -tensorflow::Status BufferAllocations::TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment) { - // Deallocate temporary buffers. - const int64 num_buffers = buffer_assignment.Allocations().size(); +BufferAllocations::~BufferAllocations() { + if (!torn_down_) { + // Presumably if we're executing this branch, the caller is in an error + // state, otherwise it would have explicitly called TearDown so it could + // save some set of live addresses. So ignoring any errors in TearDown is + // sensible. + TearDown(/*live_addresses=*/{}).IgnoreError(); + } +} + +Status BufferAllocations::TearDown( + const std::set& live_addresses) { + // Deallocate temporary buffers, taking care to try to deallocate all of them + // even if one of the deallocations fails. + Status status; + const int64 num_buffers = buffer_assignment_->Allocations().size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, // and temp buffers. if ((allocation.maybe_live_out() && !live_addresses.count(buffer_address)) || allocation.IsPreallocatedTempBuffer()) { - TF_RETURN_IF_ERROR( - memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + auto dealloc_result = + memory_allocator_->Deallocate(device_ordinal_, buffer_address); + if (!dealloc_result.ok() && status.ok()) { + status = dealloc_result; + } } } - return tensorflow::Status::OK(); + torn_down_ = true; + return status; } se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index c2fc35be4ca4bc6df85ee21fb6b564bfd6de3ec0..636623502597b3a66523938ba430e9d5a82f796c 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -48,13 +48,15 @@ class BufferAllocations { // `device_ordinal` is the number of the device this function allocates // memory on. StatusOr> Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator); private: std::map registered_buffers_; }; + ~BufferAllocations(); + BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; @@ -76,16 +78,16 @@ class BufferAllocations { // Tears down all buffers allocated by this object that are not in // `live_addresses`. - tensorflow::Status TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) + DeviceMemoryAllocator* memory_allocator, + const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + buffer_assignment_(buffer_assignment) {} // Sets the device address of buffer `buffer_index`. void SetBuffer(BufferAllocation::Index buffer_index, @@ -100,8 +102,9 @@ class BufferAllocations { se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + const BufferAssignment* buffer_assignment_; + bool torn_down_ = false; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index dce8de2e301ecfaa4674b8be48b8c02bfabf3f4b..77a48965e031349b045a956fd3f28c58607328e5 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -35,9 +35,10 @@ ConditionalThunk::ConditionalThunk( true_thunk_(std::move(true_thunk_sequence), hlo), false_thunk_(std::move(false_thunk_sequence), hlo) {} -Status ConditionalThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); +Status ConditionalThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index e40872688fdad24d24db5f7cebb3206c77652dce..ee03865d174469285a9e98b8a30fea90d997df37 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -47,7 +47,8 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 64d3b84b8c73d82800270aebcebf7f7a8fec5fe4..f0881124128c9b043392ffc4fa3aee2cd5b754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,11 +29,6 @@ namespace xla { namespace gpu { using se::dnn::AlgorithmDesc; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; ConvolutionThunk::ConvolutionThunk( CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index bf912fbd14de5874062a79db28186ab233575233..ee38c0318a878c7bcdc02afdcd146bfb4498d9a2 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,12 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( +Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -46,14 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( +Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 2e7eb5f3445bc9294fa455ef31fb816cdba4726c..8b128386f61636de9ac41e856a2b00c578e05735 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,8 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -62,8 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index c4c56c56928810d043085f284cda351391195c3b..6a46bdb9b438f81dc564b9033f5d302f90b6a997 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -35,35 +35,22 @@ class ScratchAllocator : public se::ScratchAllocator { ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - ~ScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override { return 1LL << 32; // 4GB. TODO(jlebar): Tune this? } int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; -ScratchAllocator::~ScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -se::port::StatusOr> ScratchAllocator::AllocateBytes( +StatusOr> ScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -74,19 +61,14 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } // Determines whether we can safely perform a winograd non-fused convolution for @@ -197,22 +179,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // We don't put any data in these buffers, because (in theory, anyway) the // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - se::port::StatusOr input_buf = + StatusOr maybe_input_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(input_shape)); - se::port::StatusOr filter_buf = + StatusOr maybe_filter_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(filter_shape)); - se::port::StatusOr output_buf = + StatusOr maybe_output_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(output_shape)); - if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || + !maybe_output_buf.ok()) { LOG(WARNING) << "Couldn't allocate space for input/filter/output of convolution " << instr->ToString() << ". Falling back to default algorithm."; return nullopt; } + DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); + DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); + DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the inputs + // might affect performance if e.g. the inputs contain denormals, and this is + // easy enough. + if (!stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone() + .ok()) { + LOG(WARNING) + << "Couldn't zero out input/filter/output buffer for convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; @@ -225,12 +227,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - input_buf.ValueOrDie(), filter_buf.ValueOrDie(), - output_buf.ValueOrDie(), &scratch_allocator, window, - dnums, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 5af7a77ea858563fbea05af8efd54f96a74aee93..e5e2a0478a0659986ddec8d6785827b14b9efb56 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -227,6 +227,11 @@ StatusOr GpuElementalIrEmitter::EmitLog( return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitSin( PrimitiveType prim_type, llvm::Value* value) const { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); @@ -242,6 +247,11 @@ StatusOr GpuElementalIrEmitter::EmitExp( return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 77d4569b1e8e398005e8f517ff086a77aedd382d..91f4d960aa62fff3e0699ece37a8c74d7dcf2f59 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const override; @@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index cc747addbd152eb82b0b2ef92b8653fc861f97be..e14ee6918bf148861ecccac99355fccf7ae93103 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -31,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -FftScratchAllocator::~FftScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. return kFftScratchSize; } -se::port::StatusOr> FftScratchAllocator::AllocateBytes( +StatusOr> FftScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -58,18 +47,14 @@ se::port::StatusOr> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate %lld bytes on device %d.", byte_size, - device_ordinal_); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } namespace { @@ -121,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -222,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 24b1dca99865fe21d0ca3af91e0d169f7b74a78a..b0a22564f3a09bb67a3c01723f6e37c604656d45 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -39,8 +39,6 @@ class FftScratchAllocator : public se::ScratchAllocator { FftScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator); - ~FftScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } @@ -51,7 +49,7 @@ class FftScratchAllocator : public se::ScratchAllocator { private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; @@ -73,8 +71,8 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream". - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 6e6966df3987eef29b2122c3ef8f11b7cd0bfe14..b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,19 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index c78d1c50686297aea8235af928aba562697f49bc..41ddfe0ceb1d0516c1c64feca53212a925632209 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,9 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 0ec12f52d8b398218ec370fc74bfdf6f97f43893..79fca43d022816645b8a07b9e806fe9cc3745e7c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -215,14 +215,32 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, double alpha, + const Shape& output_shape, double alpha, const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), @@ -231,12 +249,10 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -284,10 +300,12 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const MatrixDescriptor lhs_descriptor = - make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); - const MatrixDescriptor rhs_descriptor = - make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + + const MatrixDescriptor lhs_descriptor = make_descriptor( + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + const MatrixDescriptor rhs_descriptor = make_descriptor( + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. @@ -350,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index a18f425bc38fd3fbbb345901514c4ac16dbe97ec..7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -35,22 +35,20 @@ namespace gpu { class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using - // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should - // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is - // a constant. + // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - double alpha, const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, + const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to @@ -69,8 +67,6 @@ class GemmThunk : public Thunk { const Shape rhs_shape_; const Shape output_shape_; - const bool transpose_lhs_; - const bool transpose_rhs_; const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 4fdc4c89618bc0f179b2332373cb2fd3cf637390..d50153d8a31077e759bd6104d5bca8868a554fde 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -128,9 +128,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -248,7 +247,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->device_entry_computation_layout()); + hlo_module->mutable_device_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -283,12 +282,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 9db85bc788bde46c890a46ce9b0902ddce3f5675..d9560779f313d5a559c3eb0f5b28ff5dd210d9d5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -84,8 +84,13 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (ImplementedAsLibraryCall(*hlo)) { - // For all other library calls, materialize all the operands into memory. + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum) { + // For all other library calls and cross-replica-sum, materialize all the + // operands into memory. (Cross-replica-sum gets its constant args + // materialized even if it's not implemented as a libcall to simplify the + // implementation. It's slower, but we can constant fold away constant + // args *anyway*, so we just need to make it work.) for (int64 i = 0; i < hlo->operand_count(); ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 980cc89fa03abd874a8e0a694f2abb775c1de050..25d8f720ea4791a4c94efcad6909cd0c113fbe70 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -32,12 +32,15 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { namespace { +using tensorflow::tracing::ScopedAnnotation; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. @@ -134,9 +137,10 @@ Status GpuExecutable::ExecuteThunks( const BufferAllocations& buffer_allocations, bool block_host_until_done, HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + se::StreamExecutor* executor = main_stream->parent(); std::pair stream_compute_compatibility; - main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + executor->GetDeviceDescription().cuda_compute_capability( &stream_compute_compatibility.first, &stream_compute_compatibility.second); TF_RET_CHECK(stream_compute_compatibility == compute_capability_) @@ -155,21 +159,39 @@ Status GpuExecutable::ExecuteThunks( sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); - TF_ASSIGN_OR_RETURN( - sub_streams.back(), - run_options->BorrowStream(main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(sub_streams.back(), + run_options->BorrowStream(executor->device_ordinal())); } HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, sub_streams, hlo_module_->entry_computation()); uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // The next event enqueued on stream N must not run until the thunk at - // last_blocking_thunk_for_stream[N] completes. - std::map last_blocking_thunk_for_stream; + // This top-level trace serves two purposes: + // 1) It marks the scope of the whole XLA module. + // 2) It tells us whether tracing is enabled. We use this to avoid the + // expensive HloInstruction::ToString() calls inside the loop below if + // tracing is disabled. + ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + // Annotate execution of this op if tracing was enabled when we started + // running this module. If tracing is enabled *while* we're running the + // module, we won't get any data, but that's probably an OK trade-off. + // + // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), + // since we expect it to be an expensive call? + tensorflow::gtl::optional op_annotation; + if (top_level_annotation.IsEnabled()) { + op_annotation.emplace( + thunk->hlo_instruction() != nullptr + ? thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) + : "", + "XLA op"); + } + + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -179,18 +201,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - if (last_blocking_thunk_for_stream.count(stream_no)) { - stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, - last_blocking_thunk_for_stream[stream_no]) - .get()); - last_blocking_thunk_for_stream.erase(stream_no); - } - // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -198,22 +212,11 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { + if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); - - if (thunk->ShouldBlockFutureThunks()) { - // Set last_blocking_thunk_for_stream on all streams other than this one - // so that all other streams will wait for this thunk to complete before - // executing any events that occur later in the total order. - for (int32 i = 0; i < sub_streams.size() + 1; ++i) { - if (i != stream_no) { - last_blocking_thunk_for_stream[i] = thunk; - } - } - } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -286,8 +289,8 @@ StatusOr GpuExecutable::ExecuteOnStream( se::StreamExecutor* executor = run_options->stream()->parent(); TF_ASSIGN_OR_RETURN( auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); + buffer_allocations_builder.Build( + assignment_.get(), executor->device_ordinal(), memory_allocator)); bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -329,8 +332,7 @@ StatusOr GpuExecutable::ExecuteOnStream( buffers_in_result.insert(src_base); return Status::OK(); })); - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(buffers_in_result, *assignment_)); + TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); return std::move(shaped_buffer); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 51aae79c3d8d0000007f9d2926d245de838d3aca..86a3a7111fd79494e469beecf3234f6cec9adb9c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -27,8 +27,7 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment( - const ComputationLayout& entry_computation_layout) + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) : LayoutAssignment(entry_computation_layout) {} ~GpuLayoutAssignment() override {} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 7c801955943021def4ddc0accd9f318b7916ce93..4c45d2e94aebce5496da94841f6a1ae9015615c1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment(&computation_layout); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +156,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment(&computation_layout); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +225,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment(&computation_layout); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +305,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(computation_layout); + GpuLayoutAssignment layout_assignment(&computation_layout); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f13727ca9b6954f6be9b9277018fcc64ee326954..7bb8df6581b49b1bf8c84a972f715e8dc119d8de 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,8 +44,8 @@ GpuTransferManager::GpuTransferManager() /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index d040a99975230578c270deabdfe60c61649e778c..09f8227f508a3159f3def285898e15bfad544552 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -37,7 +37,7 @@ class GpuTransferManager : public GenericTransferManager { ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index ece9fa04dce3fd12713fb7e58097dc16ebba83df..e230d538cc2df826778e8d13eaaaf31ec81c57f0 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -42,6 +42,15 @@ class HloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + HloVec RemoveHlo(const HloVec& input, const std::unordered_set& remove) { HloVec result(input); @@ -65,9 +74,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -193,11 +202,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -259,24 +268,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb34700ca738b99ddd9ea615afc35da3..5d5bef6b57b57fce4255a145634745b38dccacc7 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,41 +48,100 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + if (consumer->operand_count() == 2 && + (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(alpha)) { + return true; + } + } else if (consumer->opcode() == HloOpcode::kMultiply) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -121,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6dfc9de0a11566bb3a2fb3a1b62498ffa..9fb06b0a244186484b1c17edf13bd28a4305a1a6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f8f101127b4d30bfff20c69d8cef5c1..760e0e90f583d0e43975e23b731a40af75c7dc17 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { auto module = tools::Parse(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,10 +248,93 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Parameter()))); +} + +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. +TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Parameter()))); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 532d436ee82b985a4efe300f90223e1298e85765..22e715099526c20532bb298e84e50457d89f615e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,24 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - return true; + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && @@ -98,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 71aada080ae8df70bffce3e1854b5fbd833efd23..bb47a4280541ce2806472aa9365bb0ef38c0c3b3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/lib/core/status.h" @@ -116,6 +117,26 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { + // For MOF we give the loop emitter an array for every output it should + // generate. + if (hlo.IsMultiOutputFusion()) { + std::vector target_arrays; + for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; + ++i) { + target_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (const llvm_ir::IrArray& array : target_arrays) { + tuple_operand_ptrs.push_back(array.GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, + module_); + return Status::OK(); + } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &ir_builder_) .EmitLoop(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 9f37235d32296828d3ca54f35517c3ee57607cfc..55d4c1d13d3ad41e09d48db70478cf5e6af59808 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -267,7 +267,10 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Find the largest possible power of two to unroll by. // TODO(kramerb): Make this smarter. - int64 num_elements = ShapeUtil::ElementsIn(hlo->shape()); + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); for (int i = max_unroll_factor; i > 1; i /= 2) { if (num_elements % i == 0) { return i; @@ -565,12 +568,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int unroll_factor = 1; - // TODO(kramerb): Unrolling multi-output loop fusions too. - if (!fusion->IsMultiOutputFusion()) { - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - unroll_factor = ComputeMaxUnrollFactor(fusion); - } + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); @@ -1928,6 +1927,52 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on GPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on GPU."); + } + + // CRS with one operand and one replica is simply the identity function. + // Buffer assignment expects a copy, so that's what we do. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + if (crs->operand_count() == 1) { + CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + thunk_sequence_->push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + return Status::OK(); + } + + // One-replica CRS with multiple operands produces a tuple of the inputs. + // Again, buffer assignment expects us to copy each. + std::vector> thunks; + std::vector tuple_element_buffers; + for (int64 i = 0; i < crs->operand_count(); ++i) { + tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(crs, {i}) + .ValueOrDie()); + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(i)), + /*destination_buffer=*/tuple_element_buffers.back(), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + } + + // Output a tuple of the buffers above. + thunks.push_back(MakeUnique(tuple_element_buffers, + GetAllocationSlice(*crs), crs)); + thunk_sequence_->push_back( + MakeUnique(std::move(thunks), crs)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2194,6 +2239,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2206,65 +2266,48 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. - false, // Do not transpose LHS. - false, // Do not transpose RHS. 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - alpha->literal().Get({0}), // alpha. - inst); - } else { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - 1.0, // Alpha. - inst); + CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + alpha = inst->operand(alpha->parameter_number()); + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. + inst); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2540,16 +2583,14 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( .EmitLoop(IrName(&hlo)); } - CHECK_EQ(unroll_factor, 1) - << "multi-output fusion does not support unrolling"; - // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, + unroll_factor) .EmitLoop(IrName(&hlo))); std::vector tuple_operand_ptrs; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b41ab2162ab81f66e123a7055ca3ffc815c3ef88..e42c5e86862576bad1c8610652d1c50d2364cd83 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; Status EmitTargetElementLoop( const HloInstruction& hlo, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index d376ef7a245eb9ed86939f44c611b6dde5606b23..f56c1ce69f11ed79c8be76834269f29de93a9645 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -35,26 +35,38 @@ KernelThunk::KernelThunk( kernel_name_(kernel_name), unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (loader_spec_) { - // Already initialized by another thread. - return tensorflow::Status::OK(); - } + if (!loader_spec_) { + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); + tensorflow::StringPiece ptx = executable.ptx(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. + loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. - loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), + kernel_name_); + } + } - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), kernel_name_); + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } - return tensorflow::Status::OK(); + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -62,21 +74,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; const se::KernelBase* kernel = nullptr; + { tensorflow::mutex_lock lock(mutex_); auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); - } - } + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; kernel = &it->second; } @@ -97,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index b556befe66b6bebba1a958f553f0a9b2c4eebbe4..7def27e189b66747569344a3dbe5c0c446f903be 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -57,11 +57,12 @@ class KernelThunk : public Thunk { int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. @@ -83,7 +84,8 @@ class KernelThunk : public Thunk { mutable tensorflow::mutex mutex_; std::unique_ptr loader_spec_ GUARDED_BY(mutex_); - // Loaded kernels for each `StreamExecutor` + // Loaded kernels for each `StreamExecutor`. Requires pointer stability of + // values. std::unordered_map kernel_cache_ GUARDED_BY(mutex_); }; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 86c4ac18b0501c38aaaae5a007bddcf261ca338f..7de8f9e1ee922bdbf65fd1299702482e1843f17e 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -47,7 +47,6 @@ cc_library( "@llvm//:scalar", "@llvm//:support", "@llvm//:target", - "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index d70cb07c57d48c0faed2cdc5ea9fc5ce5fb32be0..a4e4e85bf3d2c197cfc691b7fca0920aa6571729 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -273,7 +272,7 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); codegen_passes.run(*module); } @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. -tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 7bda4e2fcd469bd430e5ef1846251c8504225383..c8f0d4185c63c5bafca6f30acab31cbe8e987277 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -370,26 +370,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } return changed; } +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701daa02eba64f3e34933ce373a496c1b8..67e51509e4c717951c83c7e41943af1de762dee0 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index c8510808f10a731af90154447bd3e1e037db6348..88cb10883e97ae663dc492ad088e6daf9133d7f5 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -20,24 +20,24 @@ limitations under the License. namespace xla { namespace gpu { -SequentialThunk::SequentialThunk(std::vector>&& thunks, +SequentialThunk::SequentialThunk(std::vector> thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize( - const GpuExecutable& executable) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( +Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index df17b8d67b80321c7088243eae46e7a723b4ede9..135f79e413dfaa27f2f2264e0daa3beb3c305e0f 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -31,16 +31,17 @@ namespace gpu { // require multiple kernel launches or library calls. class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector>&& thunks, + SequentialThunk(std::vector> thunks, const HloInstruction* hlo); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; const std::vector>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8c98956f1a9b2a0bb1d304a27eb8c8cfcf610784..696fa7e0194032b5c78bf11383c3280a62de07fa 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -28,6 +28,15 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); }; @@ -41,9 +50,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -60,9 +69,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -91,24 +100,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index a0c785ed913109e987d058124c8ef49019c98500..931c0bffab850362dbd2df975657dd47d9cbd3ae 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,11 +70,14 @@ class Thunk { Kind kind() const { return kind_; } const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - // Prepares for executing the thunk. This method is called only once over - // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a - // kernel, which is the same in every execution. - virtual tensorflow::Status Initialize(const GpuExecutable& executable) { - return tensorflow::Status::OK(); + // Prepares the thunk for execution on the given StreamExecutor. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -89,21 +92,13 @@ class Thunk { return false; } - // Indicates whether thunks scheduled after this one should wait for this one - // to complete before running. For example, a convolution thunk creates a - // scratch allocator, then kicks off a convolution in cudnn via the stream - // executor. When the stream executor call returns, the scratch allocator goes - // out of scope, and the scratch memory is deallocated. In this case, the - // convolution thunk needs to return true so that future thunks wait for the - // convolution thunk to avoid reusing the deallocated memory until the - // convolution thunk is done with it. - virtual bool ShouldBlockFutureThunks() { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) = 0; + // + // Precondition: Initialize(stream->parent()) has been called. + virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index ecb54857ccc40ead21e5a18d79a37b545680021d..97cb04c38fbf18e516857f5269c984696ca204c3 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -20,8 +20,8 @@ limitations under the License. namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -40,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 8b459c29a136a6e7853e68a1bead7d12c0d08ad0..951f809b51937c97a6e7de0345ec58a8b66a4242 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,8 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index a9f3d619a3ffd6d849572355e2902375e43508fa..30b9640c4c75dae61e9a90da5fb10e9d4a90cd26 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,9 +34,11 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -Status WhileThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); +Status WhileThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Initialize(executable, executor)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index e589ca78a7ea00e7592d6e09ead9c270f902702f..22176685a92df9c95b10f755b209309843c0fa3a 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,7 +45,8 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8625f0d622dbb92bcc20802d254fe23f94..ad55728c45599c801aad7e12fac95ae9f0c4fc3b 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. // Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 3dd4c4a0794e5c41b877078c4e69c6c9584ce6c0..06a5e0351b63270b61b998ca2211f480f256f759 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -32,7 +31,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = @@ -47,7 +46,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, @@ -73,11 +72,11 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - FlatMap> live_buffers; - FlatMap> used_buffers; + FlatMap> live_buffers; + FlatMap> used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, - const LogicalBuffer* buffer) { + const BufferValue* buffer) { if (!IgnoreBuffer(buffer)) { VLOG(4) << " Adding user " << user->name() << " to buffer " << buffer->ToString(); @@ -96,7 +95,7 @@ Status HeapSimulator::RunComputation( const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); for (const HloInstruction* user : instruction->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { - for (const LogicalBuffer* buffer : buffer_set) { + for (const BufferValue* buffer : buffer_set) { add_user_to_buffer(user, buffer); } } else { @@ -104,12 +103,12 @@ Status HeapSimulator::RunComputation( // alive. It only needs the buffers that relate to the element its // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. - for (const LogicalBuffer* buffer : points_to.element({})) { + for (const BufferValue* buffer : points_to.element({})) { add_user_to_buffer(user, buffer); } const PointsToSet& gte_points_to = points_to_analysis.GetPointsToSet(user); - for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) { add_user_to_buffer(user, buffer); } } @@ -117,24 +116,25 @@ Status HeapSimulator::RunComputation( } const HloInstruction* root = computation.root_instruction(); - auto output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + BufferValueCompactPointerSet output_source_buffers = + ToBufferValueCompactPointerSet( + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet()); - std::vector dead_buffers_to_free; - std::vector operand_buffers_to_free; + std::vector dead_buffers_to_free; + std::vector operand_buffers_to_free; for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); VLOG(3) << "Instruction: " << instruction->ToString(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { VLOG(4) << " Defines: " << buffer->ToString() << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); } dead_buffers_to_free.clear(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -161,7 +161,7 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { + for (const BufferValue* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } @@ -177,7 +177,7 @@ Status HeapSimulator::RunComputation( } // Sort to get a deterministic iteration order. std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); @@ -188,7 +188,7 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -199,12 +199,12 @@ Status HeapSimulator::RunComputation( // we must be the last user of the buffer. bool shared = false; if (options_.may_reuse_operand_buffers) { - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( + points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index())) { VLOG(3) << " Sharing: " << buffer->ToString() << " with " << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); @@ -248,11 +248,11 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. - for (const LogicalBuffer* buffer : dead_buffers_to_free) { + for (const BufferValue* buffer : dead_buffers_to_free) { VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } - for (const LogicalBuffer* buffer : operand_buffers_to_free) { + for (const BufferValue* buffer : operand_buffers_to_free) { VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } @@ -261,10 +261,10 @@ Status HeapSimulator::RunComputation( // Any remaining live buffers must be entry parameters or output source // buffers, which had a nullptr sentry added. Free them now, in a // deterministic order. - std::vector to_free; + std::vector to_free; to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { - const LogicalBuffer* buffer = buffer_pending.first; + const BufferValue* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; @@ -272,10 +272,10 @@ Status HeapSimulator::RunComputation( } std::sort(to_free.begin(), to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); - for (const LogicalBuffer* buffer : to_free) { + for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } @@ -285,7 +285,7 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), @@ -297,7 +297,7 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} -bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { +bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { // Buffers for constants are ignored unless the alloc_constants option is // set. Also ignore buffers that we're not meant to assign. // @@ -311,7 +311,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { } // Alloc always calls the underlying heap algorithm. -void HeapSimulator::Alloc(const LogicalBuffer* buffer, +void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { CHECK(allocated_buffers_.count(buffer) == 0) << "Alloc called on allocated buffer: " << *buffer; @@ -331,7 +331,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer, // buffers whose group liveness has expired. Shared group liveness is tracked // by maintaining a refcount; the Free call on the last buffer in the group // causes Free to be called on the underlying algorithm. -void HeapSimulator::Free(const LogicalBuffer* buffer, +void HeapSimulator::Free(const BufferValue* buffer, const HloInstruction* instruction) { auto shared_it = shared_buffers_.find(buffer); if (shared_it != shared_buffers_.end()) { @@ -362,8 +362,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer, // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. // Both 'buffer' and 'shared' will be associated with the same SharedGroup. -void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, - const LogicalBuffer* shared, +void HeapSimulator::ShareBuffer(const BufferValue* buffer, + const BufferValue* shared, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; @@ -374,7 +374,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, CHECK(freed_buffers_.count(shared) == 0) << "ShareBuffer called on freed shared buffer: " << *shared; - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; auto shared_it = shared_buffers_.find(shared); if (shared_it != shared_buffers_.end()) { // The 'shared' buffer already has a group; it might be the canonical, but @@ -408,7 +408,7 @@ HeapSimulator::Result HeapSimulator::Finish() { // collecting statistics, e.g. NoFragmentationStatsHeap. if (!result.chunk_map.empty()) { for (const auto& share_pair : shared_buffers_) { - const LogicalBuffer* buffer = share_pair.first; + const BufferValue* buffer = share_pair.first; std::shared_ptr group = share_pair.second; if (buffer != group->canonical) { // The canonical must already exist in the chunk_map, since we called @@ -437,9 +437,9 @@ HeapSimulator::Result HeapSimulator::Finish() { } void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* share_with_canonical) { + const BufferValue* share_with_canonical) { HeapSimulatorTrace::Event* event = debug_trace_.add_events(); event->set_kind(kind); event->set_buffer_id(buffer->id()); @@ -453,14 +453,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } @@ -472,12 +472,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) { SetMode(kAlloc); run_.emplace_back(Op{buffer, size}); } -void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) { CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer; SetMode(kFree); run_.emplace_back(Op{buffer, size}); @@ -518,7 +518,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { run_.clear(); } -void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -586,7 +586,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size}); } -void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) { auto alloc_it = result_.chunk_map.find(buffer); CHECK(alloc_it != result_.chunk_map.end()) << "Free called on non-allocated buffer: " << *buffer; diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 636f19dd39f09721bd82fc4b44785f196f281ad7..8b2b43a37a5c41d334e5338c6a6fad160f03a51e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,11 +21,12 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -43,7 +44,7 @@ class HeapAlgorithm; // don't need to return the assignment of buffer offsets until the very end. class HeapSimulator { public: - // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be + // Chunk represents a contiguous piece of memory. Each BufferValue will be // associated with a chunk in the assignment result. struct Chunk { int64 offset; @@ -55,7 +56,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + tensorflow::gtl::FlatMap chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -81,7 +82,7 @@ class HeapSimulator { bool alloc_constants; // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. - const tensorflow::gtl::FlatSet* buffers_to_assign; + const BufferValueFlatSet* buffers_to_assign; }; // Run the heap simulation with the given algorithm, assuming the given @@ -97,7 +98,7 @@ class HeapSimulator { std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' @@ -109,7 +110,7 @@ class HeapSimulator { const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); private: @@ -118,7 +119,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -127,21 +128,21 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis); - bool IgnoreBuffer(const LogicalBuffer* buffer) const; - void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction); - void Free(const LogicalBuffer* buffer, const HloInstruction* instruction); - void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared, + bool IgnoreBuffer(const BufferValue* buffer) const; + void Alloc(const BufferValue* buffer, const HloInstruction* instruction); + void Free(const BufferValue* buffer, const HloInstruction* instruction); + void ShareBuffer(const BufferValue* buffer, const BufferValue* shared, const HloInstruction* instruction); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* shared_with_canonical); + const BufferValue* shared_with_canonical); const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; - const LogicalBuffer::SizeFunction size_fn_; + const BufferValue::SizeFunction size_fn_; const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; @@ -160,15 +161,15 @@ class HeapSimulator { // The shared_buffers_ map associates each shared buffer (including the // canonical) to its SharedGroup control block. struct SharedGroup { - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + tensorflow::gtl::FlatSet allocated_buffers_; + tensorflow::gtl::FlatSet freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -186,10 +187,10 @@ class HeapAlgorithm { virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. - virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Alloc(const BufferValue* buffer, int64 size) = 0; // Free de-allocates a previously allocated buffer. - virtual void Free(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Free(const BufferValue* buffer, int64 size) = 0; // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. @@ -205,8 +206,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: @@ -223,14 +224,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm { : algorithm_(std::move(algorithm)) {} ~DecreasingSizeRunsHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: // A single Alloc or Free operation that we've buffered in run_. struct Op { - const LogicalBuffer* buffer; + const BufferValue* buffer; int64 size; }; @@ -266,8 +267,8 @@ class LazyBestFitHeap : public HeapAlgorithm { LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} ~LazyBestFitHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index fd56a603bb6f849b1c1f1578fe7395d9b372e2d5..6271652412c2979ff926702f12722102344b0dfb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -39,7 +39,7 @@ const char kFree[] = "Free"; const char kFinish[] = "Finish"; // CallSequence records a sequence of Alloc/Free/Finish calls. -using CallSequence = std::vector>; +using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. class HeapCallRecorder : public HeapAlgorithm { @@ -47,7 +47,7 @@ class HeapCallRecorder : public HeapAlgorithm { explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override { + void Alloc(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kAlloc, buffer); // Instead of assigning a real offset, we set the cardinality of the Alloc // call. This isn't a valid assignment, but allows us to easily test for @@ -55,7 +55,7 @@ class HeapCallRecorder : public HeapAlgorithm { const int64 offset = result_.chunk_map.size(); result_.chunk_map.emplace(buffer, Chunk{offset, size}); } - void Free(const LogicalBuffer* buffer, int64 size) override { + void Free(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kFree, buffer); } Result Finish() override { @@ -118,7 +118,7 @@ class HeapSimulatorTracker { // Hack the size_fn so that it returns a decreasing value as we step through // the sequence. This lets us ensure the Alloc calls are in the sequence - // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // order. The Free calls are sorted by BufferValue.id, which is at least // deterministic. auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; @@ -133,8 +133,8 @@ class HeapSimulatorTracker { HloModule* module() { return module_.get(); } // Returns the buffer defined at the given instruction and index. - const LogicalBuffer* BufferAt(const HloInstruction* instruction, - const ShapeIndex& index) const { + const BufferValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { return points_to_analysis_->GetBufferDefinedAt(instruction, index) .ConsumeValueOrDie(); } @@ -150,8 +150,8 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const LogicalBuffer* a = BufferAt(instruction_a, index_a); - const LogicalBuffer* b = BufferAt(instruction_b, index_b); + const BufferValue* a = BufferAt(instruction_a, index_a); + const BufferValue* b = BufferAt(instruction_b, index_b); EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) << *a << ", " << *b; } @@ -525,7 +525,7 @@ TEST_F(HeapSimulatorTest, WholeModule) { // Now the final cond less-than buffer is allocated. {kAlloc, tracker.BufferAt(cond_lt, {})}, - // The order of the remaining Free calls is based on the LogicalBuffer.id, + // The order of the remaining Free calls is based on the BufferValue.id, // which is deterministic, but not obvious. {kFree, tracker.BufferAt(param, {})}, {kFree, tracker.BufferAt(param, {0})}, @@ -547,40 +547,40 @@ TEST_F(HeapSimulatorTest, WholeModule) { class HeapAlgorithmTestBase : public ::testing::Test { protected: HeapAlgorithmTestBase() : builder_("heap_simulator_test") { - buffer_a_ = DummyLogicalBuffer(); - buffer_b_ = DummyLogicalBuffer(); - buffer_c_ = DummyLogicalBuffer(); - buffer_d_ = DummyLogicalBuffer(); - buffer_e_ = DummyLogicalBuffer(); - buffer_f_ = DummyLogicalBuffer(); - buffer_g_ = DummyLogicalBuffer(); - buffer_h_ = DummyLogicalBuffer(); - buffer_i_ = DummyLogicalBuffer(); + buffer_a_ = DummyBufferValue(); + buffer_b_ = DummyBufferValue(); + buffer_c_ = DummyBufferValue(); + buffer_d_ = DummyBufferValue(); + buffer_e_ = DummyBufferValue(); + buffer_f_ = DummyBufferValue(); + buffer_g_ = DummyBufferValue(); + buffer_h_ = DummyBufferValue(); + buffer_i_ = DummyBufferValue(); } ~HeapAlgorithmTestBase() override {} - const LogicalBuffer* buffer_a_; - const LogicalBuffer* buffer_b_; - const LogicalBuffer* buffer_c_; - const LogicalBuffer* buffer_d_; - const LogicalBuffer* buffer_e_; - const LogicalBuffer* buffer_f_; - const LogicalBuffer* buffer_g_; - const LogicalBuffer* buffer_h_; - const LogicalBuffer* buffer_i_; + const BufferValue* buffer_a_; + const BufferValue* buffer_b_; + const BufferValue* buffer_c_; + const BufferValue* buffer_d_; + const BufferValue* buffer_e_; + const BufferValue* buffer_f_; + const BufferValue* buffer_g_; + const BufferValue* buffer_h_; + const BufferValue* buffer_i_; private: - // Create a dummy LogicalBuffer to pass to the heap algorithm. - const LogicalBuffer* DummyLogicalBuffer() { - const LogicalBuffer::Id id = buffers_.size(); + // Create a dummy BufferValue to pass to the heap algorithm. + const BufferValue* DummyBufferValue() { + const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(const0, ShapeIndex{}, id)); + buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } HloComputation::Builder builder_; - std::vector> buffers_; + std::vector> buffers_; }; class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b7a1308d3ecabb52318daa7d2852af2..1f7c1cffd324ad2f4e4cdf11046c8459b8ceb6d5 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88fb26e86b198d08b2e4db77fad671348..63c3dc4a5932f754a9ccdd70d03c999fe528a448 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -360,25 +365,38 @@ std::list HloComputation::MakeEmbeddedComputationsList() string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } - if (options.print_percent()) { - s << "%"; + + if (!options.is_in_nested_computation()) { + if (options.print_percent()) { + s << "%"; + } + s << name() << " "; } - s << name(); + if (options.print_program_shape()) { - s << " " << ShapeUtil::HumanString(ComputeProgramShape()); - } - s << " {\n"; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + } + s << "{\n"; + { + // Print the instructions in this computation. + HloPrintOptions new_options = options; + new_options.set_indent_amount(options.indent_amount() + 1) + .set_is_in_nested_computation(true); + CanonicalNameMap name_map; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < new_options.indent_amount(); i++) { + s << " "; + } + s << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) + << "\n"; } - s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString(options) << "\n"; } + for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } s << "}"; return s.str(); @@ -402,27 +420,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map) { - std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. + std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -723,18 +751,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +791,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) - << "replacements map tried to eliminate a used instruction " - << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + << "Replacements map specifies to leave out " << operand->ToString() + << ", but it is used by " << instr->ToString() << "."; + new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, + module, clone_map); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +811,24 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = FindOrDie(*clone_map, instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been - // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; - } + CHECK_NE(replaced_successor, nullptr) + << "Replacements map specifies to leave out " << successor->ToString() + << ", but it is control-depended-on by " << instr->ToString() << "."; TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); + FindOrDie(*clone_map, replaced_successor))); } } // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c2efd97681a22b6b0f6d929afc553de..8bc97df0365a32bdc89d4636ad4c7076ffb08296 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -49,9 +49,20 @@ class HloModule; // Describes a computation at the HLO level. // -// An HloComputation contains a directed acyclic graph of HLO instructions. The -// computation has a single root instruction which produces the output of the -// computation. +// You can think of an HloComputation like a function. It has some inputs +// (parameters) and returns exactly one value (the value of its root node). If +// you want to return multiple values, you can return a tuple. +// +// The instructions inside of a computation do not have an explicit total order. +// Instead, they have a partial order determined by their data and control +// dependencies. +// +// An HloModule contains one "entry computation" -- this is like main() in a C +// program. Every other computation inside of a module is attached to one or +// more HloInstructions, as a "nested computation". For example, the kMap +// instruction has a nested computation and "applies" it to every element of its +// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x), +// f(y), f(z)].) class HloComputation { public: // Builder class for HloComputation. @@ -157,14 +168,12 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. @@ -291,11 +300,17 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + // + // If the module pointer is not nullptr, then the cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the computation is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. + std::unique_ptr Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +320,9 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 7b7588f4ba9aa622677db6f9d5022cc8cc029e04..25469a54c48f4f5cab478aba929f1cc18de8b81f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) { EXPECT_FALSE(reachability->IsReachable(constant2, copy)); } +TEST_F(HloComputationTest, Stringification) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloComputationTest, StringificationIndent) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = + HloPrintOptions().set_print_metadata(false).set_indent_amount(2); + EXPECT_EQ(computation->ToString(options), + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"); +} + +TEST_F(HloComputationTest, StringificationCanonical) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); + + options = HloPrintOptions().Canonical(); + EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7b552ee5b1798c4c7e24884a392c5982d7fb17ff..5d05ccfc0b223d8749a2577ba1bf96b1ab3e761b 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 44e4f75f75b275653e1a07111943843fc6f78b33..94c9c7eabcc99d4cf61f535925c068a9b55ed136 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -142,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -329,6 +335,7 @@ Status HloCostAnalysis::HandleSelectAndScatter( Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -555,11 +562,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - // We can't do anything sane with CustomCalls, since we don't know what they - // do, and returning an error status will stop iteration over this - // computation, which is probably also not what we want. So just punt and - // return OK. This will cause all of the properties to be reported as 0, - // which is fine. + // Mark applicable fields as "unknown", since we don't know what CustomCall + // does. This is better than returning an error, which would stop iteration, + // and therefore would prevent us from getting *any* stats for a computation + // which contains a CustomCall. + current_properties_[kOptimalSecondsKey] = -1; + current_properties_[kBytesAccessedKey] = -1; + current_properties_[kFlopsKey] = -1; current_should_compute_bottleneck_time_ = false; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 9a89888480b8c79dfb1f79a50e9686bf45aa49b3..0fb65c845a6d4407c81171f6c1569fee98b1d16d 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -162,6 +162,17 @@ StatusOr MakeConcatHlo(ArraySlice operands, HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -269,7 +280,7 @@ StatusOr BroadcastZeros( StatusOr> CreateComputationWithSignature( ArraySlice domain, const Shape& range, tensorflow::StringPiece name) { - HloComputation::Builder b(name.ToString()); + HloComputation::Builder b{std::string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index c9a7361a6af0c2a0839c59a0ea695ec1b9a98bd4..49b1402d689a74874e34423a1832a0b6aa15f469 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -97,6 +97,11 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, StatusOr MakeConcatHlo( tensorflow::gtl::ArraySlice operands, int64 dimension); +// Creates a Dot HLO instruction and adds it to the computation containing `lhs` +// and `rhs` (both must be in the same computation). +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 3b22c93733af293e4d73a2b1b3ac8822dec6d5f5..c17c26c5a435fe34dd1024d596004cf6b5fdce8c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -88,6 +89,20 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { return changed; } +// An instruction is considered to be equivalent to another only if they +// share the exact same set of operands. +int64 CseHash(const HloInstruction* instruction) { + int64 hash = std::hash()(static_cast(instruction->opcode())); + hash = tensorflow::Hash64Combine( + hash, instruction->opcode() == HloOpcode::kGetTupleElement + ? instruction->tuple_index() + : -1); + for (auto operand : instruction->operands()) { + hash = tensorflow::Hash64Combine(hash, operand->unique_id()); + } + return hash; +} + } // namespace StatusOr HloCSE::Run(HloModule* module) { @@ -95,7 +110,14 @@ StatusOr HloCSE::Run(HloModule* module) { const std::function eq_instructions = std::equal_to(); const std::function - eq_computations = std::equal_to(); + eq_computations = [](const HloComputation* lhs, + const HloComputation* rhs) { return *lhs == *rhs; }; + + auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->Identical(*rhs, eq_instructions, eq_computations, + is_layout_sensitive_); + }; + for (auto* computation : module->computations()) { if (only_fusion_computations_ && !computation->IsFusionComputation()) { continue; @@ -103,13 +125,17 @@ StatusOr HloCSE::Run(HloModule* module) { changed |= CombineConstants(computation, is_layout_sensitive_); - std::list post_order = - computation->MakeInstructionPostOrder(); - std::set removed_instructions; - for (auto instruction : post_order) { - // If the instruction has already been removed by CSE skip over it. - if (removed_instructions.count(instruction) > 0 || - instruction->operand_count() == 0) { + // HLO instructions are grouped into equivalency classes by using the + // cse_equal predicate defined above. This set holds a representative + // instruction for each class. + tensorflow::gtl::FlatSet + representatives(/*N=*/1024, &CseHash, cse_equal); + + for (auto instruction : computation->MakeInstructionPostOrder()) { + // If the instruction has zero operands (constants, parameters, etc.) skip + // over it. + if (instruction->operand_count() == 0) { continue; } @@ -118,31 +144,16 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - // An instruction is considered to be equivalent to another only if they - // share the exact same set of operands. So to find equivalent - // instructions, we just search among instructions which share operand(0) - // of this instruction. - const HloInstruction* operand = instruction->operand(0); - - tensorflow::gtl::InlinedVector - equivalent_instructions; - for (HloInstruction* user : operand->users()) { - if (user != instruction && !user->HasSideEffect() && - user->Identical(*instruction, eq_instructions, eq_computations, - is_layout_sensitive_)) { - equivalent_instructions.push_back(user); - } - } - - // Replace all equivalent instructions with this instruction. - for (HloInstruction* equivalent_instruction : equivalent_instructions) { + auto it = representatives.find(instruction); + if (it != representatives.end()) { + HloInstruction* equivalent_instruction = *it; TF_RETURN_IF_ERROR( - equivalent_instruction->ReplaceAllUsesWith(instruction)); - TF_RETURN_IF_ERROR( - computation->RemoveInstruction(equivalent_instruction)); - removed_instructions.insert(equivalent_instruction); + instruction->ReplaceAllUsesWith(equivalent_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); changed = true; + continue; } + representatives.insert(instruction); } } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index df8853f34f6a72c52d1cde7332ada3809d2f3d96..9735764b692238d6a320bcff51e43b98dcadabda 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -72,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR0(84.0); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -104,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -134,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { @@ -469,5 +470,36 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } +TEST_F(HloCseTest, CompareComputations) { + auto module = tools::Parse(R"( + HloModule m + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + add_computation2 { + add_lhs2 = f32[] parameter(0) + add_rhs2 = f32[] parameter(1) + ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) + } + + ENTRY entry { + p = f32[10]{0} parameter(0) + c = f32[] constant(0) + r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation + r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 + ROOT f2 = (f32[],f32[]) tuple(r1, r2) + })") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0), root->operand(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 0c37a8d75f38dabaad886cc9d4adce8ab29ddf18..b06e6c9f3e62f375a9e48f8ef81efe7121bbef94 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -878,4 +878,128 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } +bool HloDataflowAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Iterate through all users of all uses of the fusion parameter value. + // Return false if any uses are detected, returns true otherwise. + const HloValue& value = GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + +bool HloDataflowAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7b8a74b096ff48733717e78ada5bb56a28caed72..9868746b6113881949e388cd2a4aa9f610b1fdb7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -118,6 +118,23 @@ class HloDataflowAnalysis { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 07f69b8e1339fed636e4eb54791941b85e09fd17..5798326dcbf65c3c34748afb02afab1dc7af9147 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1873,5 +1873,346 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); +class HloDataflowAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr dataflow_analysis_; +}; + +class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1071f5b184bd77cb9c57fa3fc28d4005711369b5..fa59a5fb2030b22aa9e6a59abbfba521d19adb51 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -52,25 +52,11 @@ namespace xla { namespace { using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::FlatSet; -using tensorflow::gtl::optional; - -template -struct is_complex_t : public std::false_type {}; - -template <> -struct is_complex_t : public std::true_type {}; - -template -struct is_complex64_t : public std::false_type {}; - -template <> -struct is_complex64_t : public std::true_type {}; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -108,7 +94,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -119,8 +105,8 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -138,7 +124,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -147,2092 +133,48 @@ StatusOr> Compare( return std::move(result); } -template -StatusOr> ElementWiseUnaryOpImpl( - HloInstruction* instruction, - const std::function& unary_op, - const Literal& operand_literal) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return unary_op(operand_literal.Get(multi_index)); - })); - return std::move(result); -} - -// For one particular placement of a window in a base shape (the placement is -// represented as `window_count_index`), iterates inside the window. Translates -// the window index into base index. If the base index is within bound, call `f` -// with the base index. -void IterateThroughWindow( - const Shape& window_shape, const Window& window, const Shape& base_shape, - const ArraySlice& window_count_index, - const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); - DimensionVector window_index(rank); - std::fill(window_index.begin(), window_index.end(), 0); - do { - std::vector base_index(rank); - bool out_of_bound = false; - for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - f(base_index); - } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); -} - -// Creates a vector of multipliers which can be used to create a linear index -// into shape. -// -// Given the multidimensional index {i1, ..., iN} and -// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply -// -// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. -// -// This lets you calculate LI given the multidimensional indices in any order. -DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); - int64 scale = 1; - for (auto dim : LayoutUtil::MinorToMajor(shape)) { - v[dim] = scale; - scale *= shape.dimensions(dim); - } - return v; -} - } // namespace -template -class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { - public: - explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} - - // The following higher-order functions convert a function with ElementwiseT - // to a function with ReturnT. - std::function ConvertUnaryFunction( - const std::function& unary_op) { - return [&unary_op](ReturnT arg) { - return static_cast(unary_op(static_cast(arg))); - }; - } - std::function ConvertBinaryFunction( - const std::function& - binary_op) { - return [&binary_op](ReturnT arg1, ReturnT arg2) { - return static_cast(binary_op(static_cast(arg1), - static_cast(arg2))); - }; - } - std::function ConvertTernaryFunction( - const std::function& ternary_op) { - return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { - return static_cast(ternary_op(static_cast(arg1), - static_cast(arg2), - static_cast(arg3))); - }; - } - - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); - } - - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive types. - - template ::value>::type* = - nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return std::abs(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(abs->operand(0)); - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[abs], - (ElementWiseUnaryOpImpl( - abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, - operand_literal))); - - return Status::OK(); - } - - Status HandleAbs(HloInstruction* abs) override { - // If the operand is of C64 type, the return type of abs will be F32. - // However, ElementwiseT would still be the return type, F32, and thus - // specifying the ElementwiseT explicitly as C64 is needed below. - if (abs->operand(0)->shape().element_type() == C64) { - return HandleAbs(abs); - } - return HandleAbs(abs); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { - return std::round(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); - } - - Status HandleRound(HloInstruction* round) override { - return HandleRound(round); - } - - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { - return std::ceil(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); - } - - Status HandleCeil(HloInstruction* ceil) override { - return HandleCeil(ceil); - } - - Status HandleConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).Convert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleBitcastConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleExp(HloInstruction* exp) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { - return std::exp(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { - return std::floor(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); - } - - Status HandleFloor(HloInstruction* floor) override { - return HandleFloor(floor); - } - - Status HandleLog(HloInstruction* log) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { - return std::log(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return ~elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); - } - - Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { - return NativeT(-type(elem_operand)); - })); - return Status::OK(); - } - - template ::value || - std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp( - negate, [](ElementwiseT elem_operand) { return -elem_operand; })); - return Status::OK(); - } - - Status HandleNegate(HloInstruction* negate) override { - return HandleNegate(negate); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - return (ElementwiseT(0) < elem_operand) - - (elem_operand < ElementwiseT(0)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ElementwiseT(0) - : elem_operand / abs_val; - })); - return Status::OK(); - } - - Status HandleSign(HloInstruction* sign) override { - return HandleSign(sign); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], - ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return std::atan2(lhs_elem, rhs_elem); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); - } - - Status HandleAtan2(HloInstruction* atan2) override { - return HandleAtan2(atan2); - } - - Status HandleTanh(HloInstruction* tanh) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { - return std::tanh(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); - return Status::OK(); - } - - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - - Status HandleSubtract(HloInstruction* subtract) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); - return Status::OK(); - } - - Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; - })); - return Status::OK(); - } - - Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem / rhs_elem; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::max(lhs, rhs); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); - } - - Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); - } - - template ::value>::type* = - nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::min(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); - } - - Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); - } - - Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); - } - - Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); - } - - template ::value>::type* = - nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el & rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); - } - - Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); - } - - template ::value>::type* = - nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el | rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); - } - - Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction* shl) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shl], - ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return IsShiftOutOfBounds(rhs_elem) ? 0 - : (lhs_elem << rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); - } - - Status HandleShiftLeft(HloInstruction* shl) override { - return HandleShiftLeft(shl); - } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction* shr) { - typedef typename std::make_signed::type SignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - SignedT lhs_signed = static_cast(lhs_elem); - if (IsShiftOutOfBounds(rhs_elem)) { - return lhs_signed < 0 ? static_cast(-1) : 0; - } else { - return lhs_signed >> rhs_elem; - } - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); - } - - Status HandleShiftRightArithmetic(HloInstruction* shra) override { - return HandleShiftRightArithmetic(shra); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction* shr) { - typedef typename std::make_unsigned::type UnsignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - // If shift amount is greater than the number of bits, then return 0. - if (IsShiftOutOfBounds(rhs_elem)) { - return static_cast(0); - } - return static_cast(static_cast(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); - } - - Status HandleShiftRightLogical(HloInstruction* shrl) override { - return HandleShiftRightLogical(shrl); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction* clamp) { - std::function - clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); - }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); - } - - Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); - } - - Status HandleSelect(HloInstruction* select) override { - CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); - std::function select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); - return Status::OK(); - } - - Status HandleReverse(HloInstruction* reverse) override { - const auto result_shape = reverse->shape(); - const auto reverse_dimensions = reverse->dimensions(); - - auto operand = reverse->operand(0); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReverseShape(operand->shape(), - reverse_dimensions)); - - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice out_index) { - std::vector from_index(out_index.begin(), out_index.end()); - for (const int64 dim : reverse_dimensions) { - from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; - } - return operand_literal.Get(from_index); - })); - - parent_->evaluated_[reverse] = std::move(result); - return Status::OK(); - } - - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); - const auto& window = conv->window(); - const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - - TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); - TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); - - const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); - CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 0); - CHECK_EQ(window.dimensions_size(), num_spatial_dims); - - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - std::vector window_dimension_sizes; - for (auto i : dnums.kernel_spatial_dimensions()) { - window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); - } - - const Shape& window_shape = - ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - - DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); - DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); - - auto lhs_literal_data = lhs_literal.data(); - auto rhs_literal_data = rhs_literal.data(); - - auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, - &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](ArraySlice out_index) { - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - - ElementwiseT result_val = static_cast(0); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), - 0); - - // Convolve input feature with kernel. - do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 lhs_linear_index = 0; - lhs_linear_index += out_index[output_batch_dim] * - lhs_dim_multipliers[input_batch_dim]; - lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - - int64 rhs_linear_index = 0; - rhs_linear_index += out_index[output_z_dim] * - rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - - result_val += - static_cast(lhs_literal_data[lhs_linear_index]) * - static_cast(rhs_literal_data[rhs_linear_index]); - } - cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - - return static_cast(result_val); - }; - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); - - parent_->evaluated_[conv] = std::move(result); - return Status::OK(); - } - - Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); - - const auto& dnums = dot->dot_dimension_numbers(); - - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. - CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), - rhs->shape().dimensions(rhs_contracting_dimension)) - << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracting_dimension) - << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(dot->shape()); - - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - - std::vector lhs_non_contracting_dims; - for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); - } - } - - std::vector rhs_non_batch_non_contracting_dims; - FlatSet batch_dims_set(dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); - } - } - - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice result_index) { - ElementwiseT result_val = static_cast(0); - - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set separately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; - } - - // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; - - result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); - } - - return static_cast(result_val); - })); - - parent_->evaluated_[dot] = std::move(result); - return Status::OK(); - } - - Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); - // Padding value must be scalar. - CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), - pad->padding_config().dimensions_size()); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferPadShape( - /*operand_shape=*/pad->operand(0)->shape(), - /*padding_value_shape=*/pad->operand(1)->shape(), - /*padding_config=*/pad->padding_config())); - CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - // Create new HLO of padded shape with padding value. - ReturnT scalar = - parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](ArraySlice multi_index) { return scalar; })); - - const Literal& evaluated_operand = - parent_->GetEvaluatedLiteralFor(pad->operand(0)); - - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); - - // Loop through each element of the operand, assign them to the - // corresponding index of the resulting padded literal. - const PaddingConfig& pad_config = pad->padding_config(); - - auto func = [&](ArraySlice input_index) { - for (auto i = 0; i < input_index.size(); ++i) { - // Interior padding occurs logically before edge padding, so in the case - // of negative edge padding elements are removed from the - // interior-padded operand. - target_index[i] = - pad_config.dimensions(i).edge_padding_low() + - input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); - - // Account for negative low and high padding: skip assignment if the - // any target index is out of range. - if (!(target_index[i] >= 0 && - target_index[i] < pad->shape().dimensions(i))) { - return true; - } - } - result->Set(target_index, - evaluated_operand.Get(input_index)); - return true; - }; - - std::vector zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector step(evaluated_operand.shape().dimensions_size(), 1); - - ShapeUtil::ForEachIndex( - evaluated_operand.shape(), zero_base, - AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); - - parent_->evaluated_[pad] = std::move(result); - return Status::OK(); - } - - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - auto operand = dynamic_slice->operand(0); - auto start_indices = dynamic_slice->operand(1); - auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - default: - LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - auto operand = dynamic_update_slice->operand(0); - auto update = dynamic_update_slice->operand(1); - auto start_indices = dynamic_update_slice->operand(2); - auto result_shape = dynamic_update_slice->shape(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - default: - LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - template - StatusOr> MapImpl(HloInstruction* map) { - auto operands = map->operands(); - HloComputation* computation = map->to_apply(); - - auto result = Literal::CreateFromShape(map->shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - std::vector> arg_literals; - arg_literals.reserve(operands.size()); - - // Construct scalar literal parameters to be passed to the map - // computation. - for (auto operand : operands) { - const Literal& arg_literal = - parent_->GetEvaluatedLiteralFor(operand); - - auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); - - arg_literals.push_back(std::move(curr_val_literal)); - } - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) - .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - - return computed_result->Get({}); - })); - return std::move(result); - } - - Status HandleMap(HloInstruction* map) override { - switch (map->operand(0)->shape().element_type()) { - case PRED: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F16: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], - MapImpl(map)); - break; - } - case F32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - default: - LOG(FATAL) << "HandleMap: unhandled primitive type for " - "input operand: " - << PrimitiveType_Name( - map->operand(0)->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce->shape()); - - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); - std::vector arg_dim_steps(arg_dimensions.size()); - std::vector arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Map each dimension in the result to a dimension in arg that isn't - // being reduced. - std::vector result_to_arg_index; - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index.push_back(i); - } - } - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - ReturnT result_val = init_scalar; - - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](ArraySlice input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {result_val_literal.get(), - curr_val_literal.get()}; - - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); - } - - bool IsScalarAdd(HloComputation* computation) { - HloInstruction* instruction = computation->root_instruction(); - if (instruction->opcode() == HloOpcode::kAdd && - computation->num_parameters() == 2) { - const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); - return lhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(lhs->shape()) && - rhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; - } - return false; - } - - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { - auto operand = select_and_scatter->operand(0); - auto source = select_and_scatter->operand(1); - const Window& window = select_and_scatter->window(); - - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(select_and_scatter->shape()); - - // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate( - [&](ArraySlice output_index) { return init_scalar; })); - - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - HloComputation* select = select_and_scatter->select(); - HloComputation* scatter = select_and_scatter->scatter(); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - - int64 rank = ShapeUtil::Rank(operand_literal.shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); - do { - // For each element in `source`, we place a window in `operand`. For each - // window placement, we iterate inside the window twice: - // - // 1. Find the selected index by applying `select` function to all - // elements. E.g., If the `select` function is GreaterEqual, the first - // iteration through the window finds the biggest value and returns its - // index. - // - // 2. Using the selected index, scatter value from `source` to result. We - // do this by iterating through the window, and compare each index with - // the selected index. - optional selected_val; - optional> selected_index; - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - if (!selected_val) { - selected_val = curr_val; - selected_index = operand_index; - } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - selected_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) - .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); - if (selected) { - selected_val = curr_val; - selected_index = operand_index; - } - embedded_evaluator.ResetVisitStates(); - }); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - if (std::equal(operand_index.begin(), operand_index.end(), - selected_index->begin())) { - auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) - .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); - // Clear visit states so that the we can use the evaluator again - // on the same computation. - embedded_evaluator.ResetVisitStates(); - } - }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); - - parent_->evaluated_[select_and_scatter] = std::move(result); - return Status::OK(); - } - - Status HandleReduceWindow(HloInstruction* reduce_window) override { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferReduceWindowShape( - /*operand_shape=*/reduce_window->operand(0)->shape(), - /*init_value=*/reduce_window->operand(1)->shape(), window, - /*to_apply_shape=*/function->ComputeProgramShape())); - TF_RET_CHECK( - ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) - << "return shape is set to: " - << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanStringWithLayout(inferred_return_shape); - - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); - VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); - VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce_window->shape()); - - // Creates a Shape object from window, for iteration below. - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice output_index) { - ReturnT result_val = init_scalar; - - std::fill(window_index.begin(), window_index.end(), 0); - std::fill(operand_index.begin(), operand_index.end(), 0); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), output_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = { - result_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that the we can use the evaluate again - // on the same computation. - embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get({}); - }); - - return result_val; - })); - - parent_->evaluated_[reduce_window] = std::move(result); - return Status::OK(); - } - - Status HandleSlice(HloInstruction* slice) override { - auto operand = slice->operand(0); - const Shape& shape = slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferSliceShape( - operand->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const int64 rank = ShapeUtil::Rank(operand->shape()); - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](ArraySlice out_index) { - DimensionVector operand_index(rank); - for (int64 i = 0; i < rank; ++i) { - operand_index[i] = - slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); - } - return operand_literal.Get(operand_index); - }; - - auto result = Literal::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); - parent_->evaluated_[slice] = std::move(result); - return Status::OK(); - } - - // Enable CLZ only for int32 and uint32. - template < - typename NativeT, - typename std::enable_if< - (std::is_floating_point::value || - std::is_integral::value || is_complex_t::value) && - !(std::is_same::value || - std::is_same::value)>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); - } - - template ::value || - std::is_same::value>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], - ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - return 31 - tensorflow::Log2Floor(elem_operand); - })); - return Status::OK(); - } - - Status HandleClz(HloInstruction* clz) override { - return HandleClz(clz); - } - - template ::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); - } - - Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); - } - - template ::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); - } - - Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[reduce_precision], - ElementWiseUnaryOp(reduce_precision, [reduce_precision]( - ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast(elem); - const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); - const uint32_t exponent_bits = reduce_precision->exponent_bits(); - - // Code is based on the CPU/GPU implementation in LLVM-emitting code. - // - // Bits in float type: - // mantissa : bits [0:22] - // exponent : bits [23:30] - // sign : bits [31] - if (mantissa_bits < 23) { - const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); - - // Compute rounding bias for round-to-nearest with ties to even. - // This is equal to a base value of 0111... plus one bit if the last - // remaining mantissa bit is 1. - const uint32_t base_rounding_bias = - (last_mantissa_bit_mask >> 1) - 1; - const uint32_t x_last_mantissa_bit = - (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); - const uint32_t x_rounding_bias = - x_last_mantissa_bit + base_rounding_bias; - - // Add rounding bias, and mask out truncated bits. Note that the - // case where adding the rounding bias overflows into the exponent - // bits is correct; the non-masked mantissa bits will all be zero, - // and the exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - value_as_int = value_as_int + x_rounding_bias; - value_as_int = value_as_int & truncation_mask; - } - if (exponent_bits < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the - // most- significant bit -- is equal to 1.0f for all exponent sizes. - // Adding 2^(n-1)-1 to this gives us the highest non-infinite - // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from - // this gives us the lowest' exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n - // is (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (exponent_bits - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; - const bool x_overflows = x_exponent > (reduced_max_exponent << 23); - const bool x_underflows = - x_exponent <= (reduced_min_exponent << 23); - - // Compute appropriately-signed values of zero and infinity. - const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; - const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; - - // Force to zero or infinity if overflow or underflow. (Note that - // this truncates all denormal values to zero, rather than rounding - // them.) - value_as_int = x_overflows ? x_signed_inf : value_as_int; - value_as_int = x_underflows ? x_signed_zero : value_as_int; - } - - float reduced_result = tensorflow::bit_cast(value_as_int); - if (std::isnan(elem)) { - reduced_result = mantissa_bits > 0 - ? elem - : std::numeric_limits::infinity(); - } - return reduced_result; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); - } - - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return HandleReducePrecision(reduce_precision); - } - - private: - template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - - std::vector operand_indices(start.size()); - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < operand_indices.size(); ++i) { - CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); - } - - auto result = operand_literal.Get(operand_indices); - return result; - })); - - return std::move(result); - } - - template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); - std::vector start(rank, 0); - for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); - } - std::vector result_index(rank, 0); - - auto func = [&](ArraySlice update_index) { - std::transform(update_index.begin(), update_index.end(), start.begin(), - result_index.begin(), std::plus()); - // Same as above, wrap-around only to match other implementations' - // semantics. - std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus()); - result->Set(result_index, - update_literal.Get(update_index)); - return true; - }; - - std::vector base(update_literal.shape().dimensions_size(), 0); - std::vector step(update_literal.shape().dimensions_size(), 1); - ShapeUtil::ForEachIndex(update_literal.shape(), base, - AsInt64Slice(update_literal.shape().dimensions()), - step, func); - - return std::move(result); - } - - StatusOr> ElementWiseUnaryOp( - HloInstruction* instruction, - const std::function& unary_op) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - TF_ASSIGN_OR_RETURN( - auto result_literal, - (ElementWiseUnaryOpImpl( - instruction, ConvertUnaryFunction(unary_op), operand_literal))); - - return std::move(result_literal); - } - - StatusOr> ElementWiseBinaryOp( - HloInstruction* instruction, - const std::function& - binary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ConvertBinaryFunction(binary_op)( - lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); - return std::move(result); - } - - template - StatusOr> ElementwiseTernaryOp( - HloInstruction* instruction, - const std::function& ternary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. - if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ternary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index), - ehs_literal.Get(multi_index)); - })); - - return std::move(result); - } - - template - static bool IsShiftOutOfBounds(NativeT rhs) { - typedef typename std::make_unsigned::type UnsignedT; - UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; - UnsignedT rhs_unsigned = static_cast(rhs); - return rhs_unsigned >= lhs_size_unsigned; - } - - HloEvaluator* parent_; -}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = + MakeUnique>(this); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. - typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[BF16] = + MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); } @@ -3009,8 +951,8 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* true_computation = conditional->true_computation(); auto* false_computation = conditional->false_computation(); - auto result = Literal::CreateFromShape(conditional->shape()); HloEvaluator embedded_evaluator; + std::unique_ptr result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -3034,7 +976,7 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. // This would also handle output array of tuple types as the DefaultAction - // would go through the TypedVisitor which doesn't handle tuples. + // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c0dcee0c3e382f74de72a2b89f39e06f042e2b80..566d53a41427119ea3d429a60a4430068bc953b1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -109,19 +110,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { substitutions); protected: - // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this rule, notably: - // - HandleCompare and HandleIsFinite: where the resulting literal type is - // always boolean. - // These operations are handled outside of the parent HloEvaluator handlers - // instead of from within TypedVisitor. + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this + // class. // - // Type params: - // - ReturnT: The type of input and output of each operation. - // - ElementwiseT: The type in which internal computation are done. - template - class TypedVisitor; + // A straightforward implementation would be to make it a nested class + // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor + // lives as a separate class with its own header because its template gets + // instantiated many times and we want to use extern templates to shard out + // the compilation of those instantiations across multiple cc files. + template + friend class HloEvaluatorTypedVisitor; // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. @@ -168,7 +166,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; - private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -183,14 +180,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return *(it->second); } - // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in @@ -199,6 +188,41 @@ class HloEvaluator : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap> evaluated_; + private: + template + static StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); + })); + return std::move(result); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cc16446778cbeac5ec4bed110adc9be8620084fe..ae5b5e0412ef99db9b72d645a954759ca0b9eb8b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - LiteralTestUtil::ExpectNear(*expected, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); } else { - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } } @@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } bool use_bfloat16_; @@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies Reshape operation is correctly evaluated. @@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } PaddingConfig CreatePaddingConfig( @@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = Literal::CreateR4FromArray4D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { auto expected = Literal::CreateR1({22.f, 28.f}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, @@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { auto expected = Literal::CreateR1({6, 18}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); - LiteralTestUtil::ExpectEqual(*result_literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { {19}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { {5, -6, -7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { {5, 6, 7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { result_inner_literal.get(), }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { auto result = evaluator.EvaluateWithSubstitutions( add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1823,9 +1823,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1847,9 +1847,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1872,10 +1872,10 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 2}, {2, 1}}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1900,9 +1900,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1928,9 +1928,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1952,9 +1952,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1977,9 +1977,9 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{2, 1}, {1, 1}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2000,9 +2000,34 @@ ENTRY main { ParseAndVerifyModule(hlo_text); std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..024e8751f79b8b73cb868f6cbd4603f3e94ca7ea --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -0,0 +1,2170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// TODO(b/79274244): We'd like these type traits to live inside of +// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that +// crashes clang in the frontend. +// +// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is +// a "private" header that's not exposed outside of hlo_evaluator.cc. +template +using is_complex_t = std::is_same; +template +using is_complex64_t = std::is_same; + +// Templated DfsHloVisitor for use by HloEvaluator. +// +// Typically ReturnT here indicates the resulting literal type of each evaluated +// Handle* method of a TypedVisitor. There are however a few notable exceptions +// to this rule, notably: +// - HandleCompare and HandleIsFinite: where the resulting literal type is +// always boolean. +// These operations are handled outside of the parent HloEvaluator handlers +// instead of from within TypedVisitor. +// +// Type params: +// - ReturnT: The type of input and output of each operation. +// - ElementwiseT: The type in which internal computation are done. +// +// This a logically a private part of HloEvaluator. It lives in this header +// file rather than in hlo_evaluator.cc because we use extern templates and a +// bunch of independent cc files to speed up compiling the many instantiations +// of this class. +template +class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} + + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive type. + + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (HloEvaluator::ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } + return HandleAbs(abs); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + + Status HandleBroadcast(HloInstruction* broadcast) override { + const Literal& operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector broadcast_indices( + ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand_to_broadcast.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand_to_broadcast.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand_to_broadcast.shape().dimensions(i)); + } + + auto output = MakeUnique(broadcast->shape()); + TF_RETURN_IF_ERROR(output->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + })); + parent_->evaluated_[broadcast] = std::move(output); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil(ceil); + } + + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleExp(HloInstruction* exp) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::expm1(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Expm1"); + } + + Status HandleExpm1(HloInstruction* floor) override { + return HandleExpm1(floor); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor(floor); + } + + Status HandleLog(HloInstruction* log) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::log1p(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Log1p"); + } + + Status HandleLog1p(HloInstruction* floor) override { + return HandleLog1p(floor); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_) override { + return HandleNot(not_); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ElementwiseT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign(sign); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + + Status HandleTanh(HloInstruction* tanh) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + + Status HandleSubtract(HloInstruction* subtract) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template ::value>::type* = + nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } + + Status HandlePower(HloInstruction* power) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder(remainder); + } + + template ::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd(and_); + } + + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr(or_); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return IsShiftOutOfBounds(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft(shl); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic(shra); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. + if (IsShiftOutOfBounds(rhs_elem)) { + return static_cast(0); + } + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical(shrl); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return std::fmin(high, std::fmax(value, low)); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp(clamp); + } + + Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementwiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + auto operand = reverse->operand(0); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = MakeUnique(result_shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + std::vector from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(ShapeUtil::IsArray(lhs_shape)); + CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = ShapeUtil::Rank(lhs_shape); + const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, + window, dnums)); + CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); + } + + const Shape& window_shape = + ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); + + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); + + auto lhs_literal_data = lhs_literal.data(); + auto rhs_literal_data = rhs_literal.data(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data]( + tensorflow::gtl::ArraySlice out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + ElementwiseT result_val = static_cast(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < + lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + result_val += + static_cast(lhs_literal_data[lhs_linear_index]) * + static_cast(rhs_literal_data[rhs_linear_index]); + } + cnt : {} + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return static_cast(result_val); + }; + + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + } + + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice result_index) { + ElementwiseT result_val = static_cast(0); + + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set seperately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index. + // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effecting lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contractng_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. Instead, the non_contracting_dim of rhs must + // be calculated as following: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. + for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; + } + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; + + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); + } + + return static_cast(result_val); + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. + CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create new HLO of padded shape with padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = MakeUnique(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand, assign them to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if the + // any target index is out of range. + if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + } + + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + auto operand = dynamic_slice->operand(0); + auto start_indices = dynamic_slice->operand(1); + auto result_shape = dynamic_slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), + dynamic_slice->dynamic_slice_sizes())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + default: + LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + auto operand = dynamic_update_slice->operand(0); + auto update = dynamic_update_slice->operand(1); + auto start_indices = dynamic_update_slice->operand(2); + auto result_shape = dynamic_update_slice->shape(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + default: + LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = MakeUnique(map->shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == + ShapeUtil::Rank(arg->shape()) - dimensions.size()); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReduceShape( + /*arg=*/arg->shape(), + /*init_value=*/init_value->shape(), + /*dimensions_to_reduce=*/dimensions, + /*to_apply=*/function->ComputeProgramShape())); + TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); + const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); + VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + std::vector arg_dim_steps(arg_dimensions.size()); + std::vector arg_dim_counts(arg_dimensions.size()); + for (const int64 dim : dimensions) { + arg_dim_steps[dim] = 1; + arg_dim_counts[dim] = arg_dimensions[dim]; + } + + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; + for (int64 i = 0; i < arg_dimensions.size(); ++i) { + if (arg_dim_steps[i] == 0) { + result_to_arg_index.push_back(i); + } + } + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = init_scalar; + + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto curr_val = arg_literal.Get(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = Literal::CreateR0(curr_val); + auto result_val_literal = Literal::CreateR0(result_val); + std::vector args = {result_val_literal.get(), + curr_val_literal.get()}; + + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); + return Status::OK(); + } + + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = MakeUnique(select_and_scatter->shape()); + + // Initialize result array with the init value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. + tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + bool selected = !computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + + Status HandleReduceWindow(HloInstruction* reduce_window) override { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferReduceWindowShape( + /*operand_shape=*/reduce_window->operand(0)->shape(), + /*init_value=*/reduce_window->operand(1)->shape(), window, + /*to_apply_shape=*/function->ComputeProgramShape())); + TF_RET_CHECK( + ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) + << "return shape is set to: " + << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanStringWithLayout(inferred_return_shape); + + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); + VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); + VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + // Creates a Shape object from window, for iteration below. + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + DimensionVector window_index(window.dimensions_size()); + DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + ReturnT result_val = init_scalar; + + std::fill(window_index.begin(), window_index.end(), 0); + std::fill(operand_index.begin(), operand_index.end(), 0); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = { + result_val_literal.get(), curr_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + } + + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); + const Shape& shape = slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferSliceShape( + operand->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); + TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const int64 rank = ShapeUtil::Rank(operand->shape()); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto func = [&](tensorflow::gtl::ArraySlice out_index) { + DimensionVector operand_index(rank); + for (int64 i = 0; i < rank; ++i) { + operand_index[i] = + slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); + } + return operand_literal.Get(operand_index); + }; + + auto result = Literal::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + TF_RETURN_IF_ERROR(result->Populate(func)); + parent_->evaluated_[slice] = std::move(result); + return Status::OK(); + } + + // Enable CLZ only for int32, uint32, int64 and uint64. + template < + typename NativeT, + typename std::enable_if< + (std::is_floating_point::value || + std::is_integral::value || is_complex_t::value) && + !(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + return InvalidArgument("Unsupported type for Clz"); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 31 - tensorflow::Log2Floor(elem_operand); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 63 - tensorflow::Log2Floor64(elem_operand); + })); + return Status::OK(); + } + + Status HandleClz(HloInstruction* clz) override { + return HandleClz(clz); + } + + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + return HandleCos(cos); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most- significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest' exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ? x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); + } + + private: + // Creates a vector of multipliers which can be used to create a linear index + // into shape. + // + // Given the multidimensional index {i1, ..., iN} and + // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply + // + // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. + // + // This lets you calculate LI given the multidimensional indices in any order. + static DimensionVector MakeDimMultipliers(const Shape& shape) { + DimensionVector v(ShapeUtil::Rank(shape)); + int64 scale = 1; + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + v[dim] = scale; + scale *= shape.dimensions(dim); + } + return v; + } + + // For one particular placement of a window in a base shape (the placement is + // represented as `window_count_index`), iterates inside the window. + // Translates the window index into base index. If the base index is within + // bound, call `f` with the base index. + static void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } + + template + StatusOr> DynamicSlice( + const Literal& operand_literal, const Literal& start_indices_literal, + const Shape& result_shape) { + auto start_indices_typed = start_indices_literal.data(); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(0LL, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } + + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < operand_indices.size(); ++i) { + CHECK_GE(multi_index[i] + start[i], 0); + operand_indices[i] = multi_index[i] + start[i]; + } + + auto result = operand_literal.Get(operand_indices); + return result; + })); + + return std::move(result); + } + + template + StatusOr> DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.CloneToUnique(); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + for (int64 i = 0; i < rank; ++i) { + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + } + std::vector result_index(rank, 0); + + auto func = [&](tensorflow::gtl::ArraySlice update_index) { + std::transform(update_index.begin(), update_index.end(), start.begin(), + result_index.begin(), std::plus()); + result->Set(result_index, + update_literal.Get(update_index)); + return true; + }; + + std::vector base(update_literal.shape().dimensions_size(), 0); + std::vector step(update_literal.shape().dimensions_size(), 1); + ShapeUtil::ForEachIndex(update_literal.shape(), base, + AsInt64Slice(update_literal.shape().dimensions()), + step, func); + + return std::move(result); + } + + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (HloEvaluator::ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); + } + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& + binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. + if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + return std::move(result); + } + + template + StatusOr> ElementwiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. + if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); + })); + + return std::move(result); + } + + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + + HloEvaluator* parent_; +}; + +// These extern templates prevent users of this class from implicitly +// instantiating it. We explicitly instantiate this class in the various +// hlo_evaluator_typed_visitor*.cc files. +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc new file mode 100644 index 0000000000000000000000000000000000000000..39c352dfb966af4ad9f1874d078b92dd2a321783 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc new file mode 100644 index 0000000000000000000000000000000000000000..289b40fa06d37b8f5b2705e7de2f479c4a30e89d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cb4eb921fd3af566de5998a097423c90f0cb860 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e6252fbf8c24a7b79c7e656040a6be7be8d777f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee793ae77b1b432daece31697ad436de1683bc08 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc new file mode 100644 index 0000000000000000000000000000000000000000..038d9d39e4a5881b9f0fb1d98732132aab3aaa2c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1952ca6193958eec49fd15297f73a6c6ac22b83 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc new file mode 100644 index 0000000000000000000000000000000000000000..0cbaffb40b7128fb6e99308fbc2b48e63a3d6fac --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f4bf2a392b51abc4d37db4beab6d1ea2b0c4e3a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc new file mode 100644 index 0000000000000000000000000000000000000000..10235447e0d266a6071097e38913c3856939509b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc new file mode 100644 index 0000000000000000000000000000000000000000..8abeaa6ffca4409d2664de6f55850622e95bbc9d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dabd1c176eabcf6656d6de9683bbf0131456d96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index a0cb28246d3be541e798e85552436f64a3521f22..4900c813fdf037e65c6b42d027f1cbefb6ee9830 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -16,52 +16,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -class HloExecutionProfileTest : public HloTestBase { - protected: - static constexpr int64 kInstructionCyclesIndex = 0; - static constexpr int64 kInstructionNameIndex = 19; -}; +using tensorflow::strings::StrCat; +using ::testing::AllOf; +using ::testing::ContainsRegex; -// Splits `lines` into a sequence of lines delimited by newlines and then split -// each of those lines into a sequence of words delimited by spaces. Filter out -// empty words. -std::vector> SplitIntoLinesAndWords( - tensorflow::StringPiece lines) { - std::vector> result; - for (const string& line : tensorflow::str_util::Split(lines, '\n')) { - std::vector words; - for (const string& word : tensorflow::str_util::Split(line, ' ')) { - if (!word.empty()) { - words.push_back(word); - } - } - result.push_back(std::move(words)); - } - - return result; -} +class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - std::unique_ptr hlo_module = CreateNewModule(); - - HloComputation::Builder builder(TestName()); + auto hlo_module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + lhs = f32[30,30]{1,0} parameter(0) + rhs = f32[30,30]{1,0} parameter(1) + add = f32[30,30]{1,0} add(lhs, rhs) + ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + const HloInstruction* dot_instruction = + hlo_module->entry_computation()->root_instruction(); + const HloInstruction* add_instruction = dot_instruction->operand(1); Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); - HloInstruction* param_lhs = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - HloInstruction* param_rhs = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - HloInstruction* add_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloInstruction* dot_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, param_lhs, add_instruction)); - - hlo_module->AddEntryComputation(builder.Build()); auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; @@ -84,20 +64,12 @@ TEST_F(HloExecutionProfileTest, Basic) { execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); - string rendered_profile = execution_profile.ToString( - backend().default_stream_executor()->GetDeviceDescription()); - std::vector> lines_and_words = - SplitIntoLinesAndWords(rendered_profile); - ASSERT_EQ(lines_and_words.size(), 8); - - const std::vector& line_2 = lines_and_words[2]; - const std::vector& line_3 = lines_and_words[3]; - - EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); - - EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); + EXPECT_THAT(execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()), + AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)", + dot_instruction->name())), + ContainsRegex(StrCat(add_cycles, R"(\b.*%)", + add_instruction->name())))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bb4db89f0a242c66b40a5c6541a968cdeb5fb0be..17e3c405f1e5269ddf2f03c031a1137f9bb14fcc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), - label_(label.ToString()), + label_(std::string(label)), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -816,7 +825,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -916,6 +925,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -923,6 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1078,13 +1089,23 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; // Get the instruction's extra attributes excluding the names of its // subcomputations, since those are drawn explicitly in the graph. for (const auto& line : instr->ExtraAttributesToString( - HloPrintOptions().set_print_subcomputation_references(false))) { + HloPrintOptions().set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff))) { lines.push_back(HtmlLikeStringSanitize(line)); } @@ -1404,7 +1425,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1435,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1449,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3ba7fb131bfcb1287d807d785fd9774..fc8e1468aca9c2edbc22c30a41a1be8b32a1feca 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e114245021c28da26beae444dbd3d99bb5..db1c33e2f0dfa0599810ab2e8d32209e64c5c865 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -51,7 +51,7 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); @@ -109,6 +109,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -256,10 +257,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCos: case HloOpcode::kClz: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -437,7 +440,7 @@ HloInstruction::CreateCrossReplicaSum( << "Outfeed shape " << shape << " must be compatible with operand shape " << operand->shape(); instruction->AppendOperand(operand); - instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_config_ = std::string(outfeed_config); instruction->outfeed_shape_ = shape; return instruction; } @@ -792,23 +795,11 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot fusions, -// since those fusions are really just describing a type of dot rather than -// generating a novel computation. -static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { - switch (fusion_kind) { - case HloInstruction::FusionKind::kTransposeDot: - return "dot_fusion"; - default: - return "fusion"; - } -} - /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -824,7 +815,7 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; @@ -1123,7 +1114,7 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -bool HloInstruction::HasSideEffect() const { +bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -1135,16 +1126,22 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } + default: return false; + } +} + +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; + } + // Check if any of the called computations has a side effect. + for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; } } + return false; } /* static */ std::unique_ptr HloInstruction::CreateCall( @@ -1167,7 +1164,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->custom_call_target_ = custom_call_target.ToString(); + instruction->custom_call_target_ = std::string(custom_call_target); return instruction; } @@ -1179,7 +1176,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->channel_name_ = channel_name.ToString(); + instruction->channel_name_ = std::string(channel_name); instruction->cost_estimate_ns_ = cost_estimate_ns; return instruction; } @@ -1231,12 +1228,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloModule* module, CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr clone; @@ -1253,10 +1253,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1342,7 +1344,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1415,9 +1418,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1481,15 +1490,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1539,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. - for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -1619,6 +1567,8 @@ const Literal& HloInstruction::literal() const { return *literal_; } +bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } + bool HloInstruction::CanHaveDimensionsField() const { return (opcode() == HloOpcode::kReverse || opcode() == HloOpcode::kConcatenate || @@ -1739,26 +1689,30 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -1766,6 +1720,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: @@ -1778,6 +1733,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -1789,6 +1746,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; + // Broadcast, Concatenate, and Transpose need the same dimensions field. + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kTranspose: + return dimensions() == other.dimensions(); + case HloOpcode::kFusion: return fusion_kind() == other.fusion_kind() && eq_computations(fused_instructions_computation(), @@ -1801,10 +1764,7 @@ bool HloInstruction::IdenticalSlowPath( return false; case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); + return parameter_number() == other.parameter_number(); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1816,12 +1776,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConstant: return literal() == other.literal(); - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - // A reduce-precision operation is determined by the bit sizes. case HloOpcode::kReducePrecision: return exponent_bits() == other.exponent_bits() && @@ -1864,22 +1818,8 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); // Remaining instructions with special values. - case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); case HloOpcode::kGetTupleElement: return tuple_index() == other.tuple_index(); case HloOpcode::kPad: @@ -1889,11 +1829,6 @@ bool HloInstruction::IdenticalSlowPath( return slice_starts_ == other.slice_starts_ && slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -2160,28 +2095,68 @@ string PrintName(const string& name, const HloPrintOptions& options) { } // namespace string HloInstruction::ToString(const HloPrintOptions& options) const { - string result = - StrCat(PrintName(name(), options), " = ", - ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + CanonicalNameMap new_map; + return ToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string result = ""; + + // Logic to print the instruction name (e.g. "%foo = "). + if (options.canonicalize_instruction_names()) { + if (options.is_in_nested_computation()) { + // If we are canonicalizing instruction names and this is a top-level + // HloInstruction::ToString() call, don't print an instruction name. + StrAppend(&result, + PrintName(canonical_name_map->LookupOrInsert(name()), options), + " = "); + } + } else { + StrAppend(&result, PrintName(name(), options), " = "); + } + + // Print opcode, operand(s) and shape. + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToStringWithCanonicalNameMap(options, canonical_name_map), + ")"); + + // Print additional attributes. If an instruction contains a subcomputation, + // the subcomputation is also printed here. for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } string HloInstruction::OperandsToString(const HloPrintOptions& options) const { + CanonicalNameMap new_map; + return OperandsToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { + // + // In HloInstruction, sometimes a constant literal is not constructed due + // to its size. Skip the printing in this case. + if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && + ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -2215,7 +2190,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { if (options.print_operand_shape()) { str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - if (!options.compact_operands()) { + + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, Join(str, " ")); @@ -2284,7 +2266,8 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); } - if (options.print_subcomputation_references()) { + if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { extra.push_back( StrCat("condition=", PrintName(while_condition()->name(), options))); @@ -2312,8 +2295,45 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(computation->name(), options)); }))); } + } else if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kFullBodies) { + HloPrintOptions new_options = options; + new_options.set_is_in_nested_computation(true); + switch (opcode()) { + case HloOpcode::kWhile: + extra.push_back( + StrCat("condition=\n", while_condition()->ToString(new_options))); + extra.push_back(StrCat("body=\n", while_body()->ToString(new_options))); + break; + case HloOpcode::kSelectAndScatter: + extra.push_back(StrCat("select=\n", select()->ToString(new_options))); + extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); + break; + case HloOpcode::kConditional: + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + break; + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + extra.push_back( + StrCat("to_apply=\n", to_apply()->ToString(new_options))); + break; + default: + if (!called_computations().empty()) { + extra.push_back( + StrCat("calls=\n", + Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); + } + break; + } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); @@ -2351,12 +2371,13 @@ std::vector HloInstruction::ExtraAttributesToString( } // By contract, we print the custom call target even if - // !options.print_subcomputation_references(), because the call target is not + // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. if (opcode() == HloOpcode::kCustomCall) { extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2407,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2451,6 +2473,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); + } + proto.set_channel_name(channel_name_); proto.set_cost_estimate_ns(cost_estimate_ns_); @@ -2487,8 +2513,6 @@ string HloInstruction::ToCategory() const { return "input fusion"; case FusionKind::kOutput: return "output fusion"; - case FusionKind::kTransposeDot: - return "dot"; case FusionKind::kCustom: return "custom fusion"; } @@ -2673,6 +2697,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleNegate(this); case HloOpcode::kExp: return visitor->HandleExp(this); + case HloOpcode::kExpm1: + return visitor->HandleExpm1(this); case HloOpcode::kFloor: return visitor->HandleFloor(this); case HloOpcode::kCeil: @@ -2681,6 +2707,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleClz(this); case HloOpcode::kLog: return visitor->HandleLog(this); + case HloOpcode::kLog1p: + return visitor->HandleLog1p(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -2971,6 +2999,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. HloInstruction* instruction = const_cast(const_instruction); @@ -3026,10 +3055,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -3094,7 +3125,7 @@ bool HloInstruction::IsElementwise() const { bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { CHECK(IsElementwise()); - return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); + return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } namespace { @@ -3270,8 +3301,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kInput"; case HloInstruction::FusionKind::kOutput: return "kOutput"; - case HloInstruction::FusionKind::kTransposeDot: - return "kTransposeDot"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3288,9 +3317,6 @@ StatusOr StringToFusionKind( if (kind_name == "kOutput") { return HloInstruction::FusionKind::kOutput; } - if (kind_name == "kTransposeDot") { - return HloInstruction::FusionKind::kTransposeDot; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e7f5204b53186abca78033215a75828..234dbc8399de2d88209dd8dd2be58dd152ddbe76 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -60,51 +60,75 @@ class HloModule; // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: + enum class PrintSubcomputationMode { + kOff, // Do not print anything about subcomputations. + kNameOnly, // Only print the name of subcomputations. + kFullBodies, // Print the full bodies of subcomputations. + }; + // Constructs the default print options: don't print large constants, don't // compact operands, no indentation. HloPrintOptions() : print_large_constants_(false), - print_subcomputation_references_(true), + print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), print_percent_(true), - indent_amount_(0) {} + canonicalize_instruction_names_(false), + indent_amount_(0), + is_in_nested_computation_(false) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() .set_print_large_constants(true) - .set_print_subcomputation_references(true) + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); } + // Options to produce the canonical string representing an isomorphic + // computation graph. + static HloPrintOptions Canonical() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) + .set_print_metadata(false) + .set_compact_operands(true) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_canonicalize_instruction_names(true); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; return *this; } - // If true, the names of subcomputations (e.g. a fusion node's fused - // computation) won't be printed. This makes the resulting text not parsable. - // - // A CustomCall's call target is printed even if - // print_subcomputation_references is false, because the call target isn't an - // HloComputation. - HloPrintOptions& set_print_subcomputation_references(bool value) { - print_subcomputation_references_ = value; + HloPrintOptions& set_print_subcomputation_mode( + PrintSubcomputationMode value) { + print_subcomputation_mode_ = value; return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -130,54 +154,175 @@ class HloPrintOptions { return *this; } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as + // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. + HloPrintOptions& set_canonicalize_instruction_names(bool value) { + canonicalize_instruction_names_ = value; + return *this; + } + // The indent of the hlo text block. HloPrintOptions& set_indent_amount(int value) { indent_amount_ = value; return *this; } + // If true, indicates the instruction being printed is inside a nested + // computation. + HloPrintOptions& set_is_in_nested_computation(bool value) { + is_in_nested_computation_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } - bool print_subcomputation_references() const { - return print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode() const { + return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool canonicalize_instruction_names() const { + return canonicalize_instruction_names_; + } int indent_amount() const { return indent_amount_; } + int is_in_nested_computation() const { return is_in_nested_computation_; } private: bool print_large_constants_; - bool print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool canonicalize_instruction_names_; int indent_amount_; + bool is_in_nested_computation_; +}; + +// For canonical string output, we need to have a canonical way to rename +// each instruction and its operands. Each operand is renamed as "tmp_", +// where is an index starting from 0. +class CanonicalNameMap { + public: + CanonicalNameMap() : index(0) {} + + string LookupOrInsert(const string& old_name) { + auto iter = canonical_name_map.find(old_name); + if (iter != canonical_name_map.end()) { + return iter->second; + } + + string new_name = tensorflow::strings::StrCat("tmp_", index++); + canonical_name_map[old_name] = new_name; + return new_name; + } + void Clear() { + canonical_name_map.clear(); + index = 0; + } + + private: + int64 index; + tensorflow::gtl::FlatMap canonical_name_map; }; -// HLO instructions are the IR used by the high-level compiler. +// HLO instructions are the atomic unit of the high-level compiler's IR. +// +// HloInstructions live inside of an HloComputation, which is analogous to a +// function in other programming languages. Nodes have no total order within +// their computation. Instead, they have a partial ordering determined by their +// data and control dependencies. +// +// HLO does not have basic blocks or explicit "branch" instructions. Instead, +// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode +// control flow. For example, the kConditional HLO executes one of two possible +// computations, depending on the runtime value of a predicate. +// +// HLO is pure (mostly). It has no concept of mutable state. Instead, data +// values are produced by one HLO and flow into consumers across dependency +// edges. class HloInstruction { public: + // A fusion node computes the same value a call to its fusion computation + // would compute. However, the choice of fusion kind dictates codegen + // strategy for the backend. + // + // To generate code for a kFusion HloInstruction, most backends do something + // like the following: + // + // 1) Identify the "primary" HloInstruction of the fused computation. + // 2) Emit code that does the work of the primary node, creating its inputs + // and transforming its outputs as specified by the fused computation. + // + // In step (2), the code emitted is usually similar to the code that would be + // emitted for an *unfused* version of the primary node, except that + // + // - when the primary node reads an element of one of its operands, instead + // of loading the value from memory, it *computes* the value based on the + // contents of the fused computation. + // - when the primary node outputs a value, instead of storing it to memory, + // it forwards the value to its users, which then perform additional + // computations before the value is finally stored to memory at the root of + // the fusion node. + // + // An HloInstruction's FusionKind helps us find the kFusion instruction's + // primary node, and can also affect how we generate code in step (2). + // + // - kInput: The primary node is the root of the fused instruction. + // + // - kOutput: The primary node is not the root of the fused instruction. + // This fusion kind requires that one operand buffer of the fusion + // instruction be able to alias the output buffer. This constraint is + // usually enough to let backends find the primary node unambiguously. + // + // - kLoop: The primary node is the root of the fused computation, but, + // unlike in input fusion, we prescribe a specific implementation for + // codegen. Rather than generating code that looks like the code we'd emit + // for an unfused version of the primary/root node, we emit code that + // generates one element of the root at a time. + // + // - kCustom: Custom category for backend-specific fusions that don't fit + // into the above patterns. + // + // Not all backends support all fusion kinds, and given a particular fused + // computation, it's not in general safe to change its fusion kind. Creation + // of fusion nodes is always backend-specific. + // + // For elementwise ops (e.g. kAdd), most backends would emit a + // one-element-at-a-time implementation for the unfused version, so loop + // fusion and input fusion are probably equivalent if the root node is + // elementwise. They're not necessarily equivalent e.g. for kReduce, where an + // implementation might emit something more sophisticated for an unfused or + // input-fusion reduce, but will emit the naive code that reduces one element + // at a time for loop fusion with a reduce as the root. + // + // Another way to think of loop fusion is that it's equivalent to input + // fusion, but where the root node is an implicit identity node, whose + // unfused implementation is "read one element, write one element". + // + // TODO(b/79869434): This categorization scheme is not great. For one thing, + // input and loop fusion are basically the same thing: There is no reason for + // the HLO to encode backend-specific decisions about how e.g. a reduce that's + // the root of a fusion should be lowered. In addition, this scheme as + // written doesn't work for multi-output fusion, where the primary node is + // never actually the root (which is a kTuple instruction that gathers the + // multiple outputs of the fusion). enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, + kInput, + kOutput, + kCustom, }; ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. @@ -185,7 +330,7 @@ class HloInstruction { // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map); @@ -503,6 +648,10 @@ class HloInstruction { // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -597,10 +746,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -615,7 +762,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. @@ -643,6 +790,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -695,6 +844,9 @@ class HloInstruction { // Note: only constant and parameter opcodes have an associated literal. const Literal& literal() const; + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const; + // Returns the parameter number associated with this instruction. // // Note: only parameter opcodes have an associated parameter number. @@ -956,6 +1108,14 @@ class HloInstruction { void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? sharding() == other->sharding() : false; + } // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain @@ -1157,23 +1317,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction. Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1245,7 +1412,7 @@ class HloInstruction { // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } + void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1262,6 +1429,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. + const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1463,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1298,20 +1479,34 @@ class HloInstruction { const ShapeIndex& shape_index = {}); private: + // Prints an instruction to a string. + // + // The canonical string representation needs to name operands and instruction + // names in a consistent way. This is implemented through the + // canonical_name_map. + string ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Prints an operand to a string. + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and + // OperandsToStringWithCanonicalNameMap() functions. + friend class HloComputation; + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. class FusionReusesParamElements; // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( @@ -1510,6 +1705,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; @@ -1540,13 +1739,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // an HloInstruction* or a const HloInstruction*. // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of -// the hlo. +// the hlo. Exception: null pointer values compare less than non-null. // // Note that this cannot be used for HLO instructions across multiple modules // since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } return lhs->unique_id() < rhs->unique_id(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 5b65b1152c8298a8954890374626ae5329dccff9..a61c472c72804b077d21274d2e866a69c5e73157 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1102,7 +1102,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + {dot, reshape}, HloInstruction::FusionKind::kLoop); auto fusion2 = fusion->Clone(); const HloInstruction* root = fusion->fused_expression_root(); @@ -1169,7 +1169,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + {dot, b_t}, HloInstruction::FusionKind::kLoop); auto fusion = computation->CreateFusionInstruction( {add, nested_fusion}, HloInstruction::FusionKind::kOutput); @@ -1246,13 +1246,6 @@ TEST_F(HloInstructionTest, Stringification) { auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); - HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - - EXPECT_EQ( - fusion->ToString(options), - "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); @@ -1343,5 +1336,163 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().Canonical(); + + EXPECT_EQ(dot->ToString(options), + "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_EQ( + fusion->ToString(options), + R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ(loop->ToString(options), + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ( + conditional->ToString(options), + R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, false_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. +void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. +// *) For each tuple operand: +// *) For tuple output shape index associated with operand: +// *) Propgate live shape indices to tuple operand at the associated +// shape index in the operands output, and add to worklist. +void PropagateLivenessThroughTuple( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); + for (int64 operand_index = 0; operand_index < instruction->operand_count(); + ++operand_index) { + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + if (shape_index.empty() || shape_index[0] != operand_index) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index; + for (int i = 1; i < shape_index.size(); ++i) { + operand_shape_index.push_back(shape_index[i]); + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); + } +} + +// Propagates liveness through GetTupleElement instructions. +// *) For each live index in GetTupleElement output, mark output of GTE operand +// at associated shape index in its output, and add to worklist. +void PropagateLivenessThroughGTE( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // Mark operand top-level index. + MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist, + workset); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + // Propagate live shape indices along GTE -> Tuple edge. + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + ShapeIndex operand_shape_index(shape_index); + operand_shape_index.push_front(instruction->tuple_index()); + MarkLiveAtIndex(instruction->operand(0), operand_shape_index, + live_index_map, worklist, workset); + }); +} + +// Propagates liveness through While instructions. +// *) For each live index in While output, mark shape index of while.body.root +// and while.operand (adding each to worklist). +// *) Mark while.cond.root and add to worklist. +void PropagateLivenessThroughWhile( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kWhile); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while body computation root instruction. + MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index, + live_index_map, worklist, workset); + // Propagate liveness to tuple-shaped operand. + MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map, + worklist, workset); + }); + + // Propagate liveness to while condition computation root instruction. + MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); +} + +// Propagates liveness out of Parameter instructions to callers and aliasing +// positions. This can occur if liveness propagates to a parameter in the +// while.condition computation, requiring liveness to propagate out to caller +// callsite while (and while.body.root). +void PropagateLivenessToParameterCallers( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + auto* xla_while = callsite.instruction(); + const ShapeTree& index_tree = + FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while result{shape_index} + MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist, + workset); + // Propagate liveness to while body root{shape_index}. + MarkLiveAtIndex(xla_while->while_body()->root_instruction(), + shape_index, live_index_map, worklist, workset); + // Propagate liveness to operand(0){shape_index}. + MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map, + worklist, workset); + }); + } + } + } +} + +} // namespace + +HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) + : module_(module), call_graph_(CallGraph::Build(&module)) {} + +// Runs liveness analysis on 'module_'. +// Initializes worklist with entry root instruction (and any instruction with +// side-effects), marking all of their output shape indices live. +// Visits elements on worklist, propagating liveness from an instructions +// live output shape indices to its called computations and operands. +void HloLivenessAnalysis::RunAnalysis() { + Worklist worklist; + Workset workset; + // Add entry compuation root instruction. + MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(), + &live_index_map_, &worklist, &workset); + for (auto* computation : module_.computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->HasSideEffectNoRecurse()) { + // Add instructions with side effects. + MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist, + &workset); + } + } + } + + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop_front(); + workset.erase(workset.find(instruction)); + VLOG(1) << "VISIT instruction: " << instruction->name(); + + if (instruction->opcode() == HloOpcode::kTuple) { + PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { + PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kWhile && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessToParameterCallers(instruction, &live_index_map_, + &worklist, &workset, + call_graph_.get()); + } else { + // Propagate liveness to called computations. + for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. + for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..fe55a8070a42a3d68836dd32cf7ce5823dd77951 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. + bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e2e2c7627ba6ac9e5078446056917a07436cbd7 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. +TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. +TEST_F(HloLivenessAnalysisTest, DeadAdd) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + add.1 = s32[] add(constant.1, constant.2) + ROOT add.2 = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); +} + +// Test that all output shape indices of entry root tuple (and defining +// instruction in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that all outputs of nested tuple and entry root (and defining +// instruction values appearing in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(1) + constant.2 = s32[] constant(2) + constant.3 = s32[] constant(3) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE at entry root of Tuple instruction only propgates liveness +// to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that GTE at entry root of nested Tuple instruction only propgates +// liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE of GTE (at entry root) of nested Tuple instruction only +// propgates liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. +TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. +TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 69deac263ee58f9e4d46987a54f09b11d650950a..7e4b8834357d39099f76450b849d6b5624e4e3b4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -17,10 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { +using ::tensorflow::str_util::Join; + bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -195,6 +198,41 @@ void HloShardingMatcher::DescribeTo(std::ostream* os) const { } } +bool HloDotWithContractingDimsMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + + const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); + if (dim_nums.lhs_contracting_dimensions_size() != 1 || + dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong lhs_contracting_dimensions (got {" + << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" + << lhs_contracting_dim_ << "})"; + return false; + } + + if (dim_nums.rhs_contracting_dimensions_size() != 1 || + dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong rhs_contracting_dimensions (got {" + << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" + << rhs_contracting_dim_ << "})"; + return false; + } + + return true; +} + +void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with lhs_contracting_dims={" << lhs_contracting_dim_ + << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}"; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 5175736a2506c85836577a7f2ba2359a3d5a6b18..c33bdadf1c7145bf2aff09b01423c6c21382da0c 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -131,6 +131,27 @@ class HloShardingMatcher tensorflow::gtl::optional sharding_; }; +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractingDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractingDimsMatcher( + ::testing::Matcher lhs, + ::testing::Matcher rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -158,7 +179,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -310,6 +330,30 @@ inline ::testing::Matcher NoSharding() { new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); } +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractingDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index f2463060b7cd653dffb408f8df17f44fe0c1a97c..016cc01e33840aa195dfc0a21e8ac8f3d24a3e06 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -165,5 +166,41 @@ TEST(HloMatchersTest, ShardingMatcher) { "has incorrect sharding (expected: {maximal device=0})"); } +TEST(HloMatchersTest, DotMatcher) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0)); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c7a719286753914ff39dc1fd528e74bc7fab2d7b..fbf1d58007e318a8a08aa9e11d9d54811533703e 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { @@ -254,24 +266,44 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_id] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_id() == computation_id, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f9674df812dbbc9c5a99c7c57a18b800f23ee36f..02918c377776b73f2086fe41afc406567a12af4c 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -42,10 +42,18 @@ namespace xla { // Describes a compilation unit at the HLO level. // -// A HLO module contains one or more HLO computations. The module contains one -// "entry" computation which produces the result. The module also includes any -// embedded computations used by instructions such as "map" and "reduce". All -// computations are owned by the module. +// HloModule is the top-level unit in the HLO IR. It corresponds to a whole +// "program". Running a module, from beginning to end, is the only way to run +// an XLA program. +// +// A module contains one "entry computation"; this HloComputation is like main() +// in a C program. The result of running the module is the result of running +// this computation. +// +// A module also contains some number of "nested computations". Each nested +// computation is attached to an HloInstruction within some other computation. +// The meaning of the nested computation depends on the instruction it's +// attached to. class HloModule { public: HloModule(const string& name, @@ -217,6 +225,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d20315e399c6b1a3979b5d11a89ef93869f4d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' as simple pass-through parameter (which candidate + // be removed later by simplification pass). + HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making prior operand a dead root (to be cleaned + // up with a subsequent DCE pass). + TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 0000000000000000000000000000000000000000..29024085c1038961ef2b3721de1ce0e8a55ccf45 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..53b7d0ed3964ca8a2c3bb73c62015a1c7dbfe487 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. + bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. +TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element pass-through the while body. +TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propgates liveness of this +// tuple element to while.body{1} and at while.result{1}. +TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} still be pass-through after ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 54c34ce116651608e6d91cdcba9c708ca3a5f75e..b4cd3c730e323b8459312edbebc564e08f9d6840 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include #include #include @@ -47,13 +48,16 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { case ComputationKind::kConditionalFalse: repr += ":CONDITIONAL_FALSE"; break; + case ComputationKind::kCallFunction: + repr += ":CALL"; + break; } return repr; } /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = absl::make_unique(modules); + auto metadata = MakeUnique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -107,6 +111,31 @@ Status HloModuleGroupMetadata::Build() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + TF_RETURN_IF_ERROR(VerifyCompanionSets()); + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyCompanionSets() const { + // TODO(dlibenzi): Migrate this to use the device instead of module ID, once + // the kDomain CL goes in. + for (const auto& companions : companion_sets_) { + // A companion set must be composed at most of an instruction per + // device/module. + std::unordered_set devices; + for (HloInstruction* instruction : *companions) { + int64 device = GetModuleId(instruction->parent()->parent()); + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << " (" + << GetModuleId(hlo->parent()->parent()) << ")" << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); + } + } + } return Status::OK(); } @@ -206,6 +235,9 @@ Status HloModuleGroupMetadata::RecordInstructions() { TrackedInstruction(hlo, ComputationKind::kConditionalTrue); tracked_instructions_[hlo->false_computation()] = TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } else if (hlo->opcode() == HloOpcode::kCall) { + tracked_instructions_[hlo->to_apply()] = + TrackedInstruction(hlo, ComputationKind::kCallFunction); } if (!IsChannelInstruction(hlo)) { return Status::OK(); @@ -258,14 +290,15 @@ Status HloModuleGroupMetadata::RecordInstructions() { Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2) { TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || - instruction1->opcode() == HloOpcode::kConditional); + instruction1->opcode() == HloOpcode::kConditional || + instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + tensorflow::MakeUnique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -336,21 +369,11 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { } } - // Check if channel instructions are used only in allowed computations. - const auto allowed = [this](HloInstruction* hlo) { - HloComputation* computation = hlo->parent(); - const HloModule* module = computation->parent(); - if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { - return true; - } - return false; - }; for (const Channel& channel : channels_) { - if (!allowed(channel.send) || !allowed(channel.send_done) || - !allowed(channel.recv) || !allowed(channel.recv_done)) { - return FailedPrecondition("channel is used in disallowed computation"); - } + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done)); } // Check if the nest levels match for each channel. for (const Channel& channel : channels_) { @@ -368,4 +391,15 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { return Status::OK(); } +Status HloModuleGroupMetadata::CheckCommunicatingInstruction( + HloInstruction* instruction) const { + HloComputation* computation = instruction->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return Status::OK(); + } + return FailedPrecondition("channel is used in disallowed computation"); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index c48a7ab0b59269474f7406ef24a249355528e085..3ef4542f9129632de4975688ae7e9e2c5f43a7ee 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -60,6 +60,7 @@ class HloModuleGroupMetadata { kWhileBody, kConditionalTrue, kConditionalFalse, + kCallFunction, }; // Tracks the instruction mapped to a given computation, and the computation @@ -202,6 +203,15 @@ class HloModuleGroupMetadata { Status AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2); + // Checks whether a communicating instruction is placed in a valid position + // within the graph. + Status CheckCommunicatingInstruction(HloInstruction* instruction) const; + + // Performs a consistency check on the companion sets built for the input + // modules. Check that a companion set does not include instructions from the + // same module/device. + Status VerifyCompanionSets() const; + // Retrieves a pointer to the stored TrackedInstruction associated with a // tracked computation, or nullptr in case such computation is not tracked. const TrackedInstruction* GetTrackedInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 289c96b0a7b90c5f8a122cd3fc327a5762099106..5a0d1e264eb5095ff53721416ebcf4842a063f97 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -289,7 +290,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = absl::make_unique(post_order); + auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ca763076a16af1150a8623fb7dbf22c46a5ca263..ac7cd2f2f517cf8831416d9265fc48bbf9fce340 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -74,6 +74,7 @@ namespace xla { V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kExpm1, "exponential-minus-one") \ V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ @@ -87,6 +88,7 @@ namespace xla { V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ + V(kLog1p, "log-plus-one") \ V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index e89d94bede6c437ca1131a1b1b0098390d58c0d9..dcd4725fe78e8b9b5d14437e964cb5aaf1664117 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -170,10 +169,10 @@ bool HloOrdering::UseIsBeforeValueDefinition( // is before the def if the instruction allows buffer sharing (in place // computation). if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( + dataflow.CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index(), dataflow)) { + value.defining_index())) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5120775737bfa32bbb656421216f2b3fbef590ea..d8f1ab916b5c5c500c2d8dcd8605be083f95862a 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -90,7 +90,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = name().ToString() + ": pipeline start"; + string prefix = std::string(name()) + ": pipeline start"; bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +98,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), - "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + std::string(name()), "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(std::string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -121,7 +121,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - name().ToString(), pass->name().ToString()); + std::string(name()), std::string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b171d41a31ed23f0886e7363289ea56c92216572..39b85de0f12024f5e20ddd37618987c6d06bc307 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -274,9 +273,8 @@ ItemList GetUsers(const InstructionList& instruction_list, for (const BufferAlias& buffer_alias : points_to_analysis.GetBufferAliases(*logical_buffer)) { for (const HloInstruction* user : buffer_alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(buffer_alias.instruction(), - buffer_alias.index(), user, - points_to_analysis)) { + if (points_to_analysis.DoesNotUseOperandBuffer( + buffer_alias.instruction(), buffer_alias.index(), user)) { // The alias may be an operand of 'user', but the LogicalBuffer cannot // possibly be used by the instruction so ignore 'user'. This is the // case, for example, for the tuple element buffers in a GetTupleElement diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 48da1a505c9bea72378aaba7824548cca0eef447..2a601ec3d183023954b6f1b6bca7594384378169 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -171,7 +170,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(absl::make_unique(executor)); + streams.push_back(MakeUnique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -198,7 +197,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = absl::make_unique( + pool = MakeUnique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -229,7 +228,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + auto literal = MakeUnique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 1a767628f6e2d33df353366974fb866e89f0df5a..854aa943199397c0e3f84d48a74ef41ae0d3db56 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -62,7 +62,34 @@ StatusOr MinimumMemoryForSequence( namespace { // Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. As a counterexample, considering the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to shedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). +// class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions @@ -70,8 +97,11 @@ class ListScheduler { static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); return scheduler.CreateSchedule(); } @@ -92,10 +122,13 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function) { + size_function_(size_function), + memory_by_computation_(memory_by_computation) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -185,6 +218,12 @@ class ListScheduler { } // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -194,7 +233,19 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - return freed_bytes - entry.bytes_defined; + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; } // Constructs the scheduling priority of the given instruction. @@ -315,6 +366,11 @@ class ListScheduler { const HloComputation& computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. + const tensorflow::gtl::FlatMap& + memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. tensorflow::gtl::FlatMap> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + StatusOr MinimumMemoryForComputation( const HloComputation& computation, const std::vector& sequence, @@ -352,30 +426,17 @@ StatusOr MinimumMemoryForComputation( return result.heap_size; } -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm) { - return algorithm(computation, points_to_analysis, size_function); - } - return DefaultMemoryScheduler(computation, points_to_analysis, size_function); -} - -} // namespace - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // This ordering is based on DFS post-order, with a heuristic to decide which // operand to visit first. The heuristic is based on 'extra_users', which is // simply users-1 for each instruction. By subtracting 1, we're saying that // instructions with no users or a single user don't count; instructions with // lots of fan-out will be visited earlier. - int64 cumulative_total_size = 0; tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -388,14 +449,12 @@ StatusOr> DFSMemoryScheduler( int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; - cumulative_total_size += logical_buffer_size; tensorflow::gtl::FlatSet unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } - total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -421,52 +480,87 @@ StatusOr> DFSMemoryScheduler( })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; -} +} // namespace xla StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - return ListScheduler::Run(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); +} + +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector{post_order.begin(), + post_order.end()}; } StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListMemoryScheduler(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - DFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - if (list_memory <= dfs_memory) { + TF_ASSIGN_OR_RETURN( + std::vector post_order_sequence, + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN( + const int64 post_order_memory, + MinimumMemoryForComputation(computation, post_order_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + + if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; - } else { + } else if (min_memory == dfs_memory) { VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; } } @@ -477,24 +571,32 @@ CreateMemoryMinimizingSequence(const HloModule& module, SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation], - CreateMemoryMinimizingSequence(*computation, *points_to_analysis, - size_function, algorithm)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + MinimumMemoryForComputation(*computation, one_computation_sequence, + *points_to_analysis, size_function) + .ValueOrDie(); + sequence[computation] = std::move(one_computation_sequence); + } } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { + const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); + tensorflow::gtl::FlatMap empty_map; return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, algorithm); + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 068e68383deb170ded1c9b09a8b7ceb8c4c0ab4b..49b927eefd24f4e26df781dd8d2b977bedba2b80 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -34,26 +34,47 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. typedef std::function>( const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&)> + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // DFS-order scheduler StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, @@ -61,7 +82,9 @@ StatusOr> DFSMemoryScheduler( StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes @@ -72,10 +95,10 @@ CreateMemoryMinimizingSequence(const HloModule& module, const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. +// Currently only used by the GPU backend. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 92df7c1427f282ccdde2df494c41b3f2a98cf7b3..c018ba2ffc404d0c6a0d08b8f5c63a9f90888b70 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -190,5 +190,104 @@ ENTRY root { instructions_by_name.at("e"))); } +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %SubcomputationsNotAccounted () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. + EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd622ffb64d345ddaf0dc88f34850bfc..7d6d0d9eaf70969c1a3762959233b561706398c2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -106,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -116,7 +114,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -127,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -164,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -183,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -191,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -200,22 +196,18 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -410,7 +402,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,14 +482,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -541,13 +533,13 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -571,22 +563,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +594,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -612,14 +604,14 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +627,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +666,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,41 +687,40 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { @@ -778,7 +769,7 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "init: %s, body: %s", init->ToString().c_str(), body_root->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { @@ -796,7 +787,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape).c_str()); } } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr HloVerifier::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6208887547a14d22b512ef308dd2668af2f4468d..1392a78097aa026b2f7cffa2b0135402d3ca7ae5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -82,9 +82,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 13e4557317f74b3fb46f07fb91c339fd2f34752f..dc3bfce0c495bc40a2df7b985cab67e02a3e15ce 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; @@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto append_op = [&](const OpInfo& op) { + auto print_op = [&](const OpInfo& op) { + // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that + // were expected to be free and are actually free -- things like (on most + // backends) kParameter or kConstant HLOs. There's no need to clutter the + // profile with these. + if (op.optimal_seconds == 0 && op.cycles == 0) { + return; + } + string bytes_per_sec; string bytes_per_cycle; - if (op.cycles <= 0 || op.bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = - HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + if (op.cycles > 0 && op.bytes_accessed >= 0) { + bytes_per_sec = StrCat( + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)), + "/s"); + double bpc = static_cast(op.bytes_accessed) / op.cycles; if (op.bytes_accessed > op.cycles) { - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = - Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + bytes_per_cycle = Printf("%.3fB/cycle", bpc); } } @@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const { double nsecs = op.cycles / clock_rate_ghz_; Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s " - ":: %18s :: %12s/s :: %12s/cycle :: %s\n", + "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " + ":: %18s :: %14s :: %16s :: %s\n", op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds * 1e6, + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), op.flop_count <= 0 - ? "" + ? "" : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" + op.transcendental_count <= 0 ? "" : HumanReadableNumTranscendentalOps( op.transcendental_count, nsecs) .c_str(), @@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const { int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { - optimal_seconds_sum += op.optimal_seconds; - total_flops += op.flop_count; - total_transcendentals += op.transcendental_count; - total_bytes += op.bytes_accessed; + if (op.optimal_seconds > 0) { + optimal_seconds_sum += op.optimal_seconds; + } + total_flops += std::max(op.flop_count, int64{0}); + total_transcendentals += std::max(op.transcendental_count, int64{0}); + total_bytes += std::max(op.bytes_accessed, int64{0}); } VLOG(1) << "Total floating point ops: " << total_flops; - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); - // Sort ops in decreasing order of cycles. + // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); std::sort( sorted_ops.begin(), sorted_ops.end(), [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); for (const auto& op : sorted_ops) { - append_op(op); + print_op(op); } if (total_cycles_ <= 0) { @@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds above estimated optimum"); table.SetEntryName("ops"); table.SetShowCategoryTable(); + table.SetShowAllEntries(); float total_discrepancy_in_microseconds = 0.0f; - for (const auto& op : sorted_ops) { + for (const auto& op : op_infos_) { + // Skip ops with < 0 optimal seconds. These are ops for which we don't + // know the optimal time. + if (op.optimal_seconds < 0) { + continue; + } + // Also skip ops with 0 actual cycles. These ops were free; there's no + // need to clutter the "above estimated optimum" table with them, + // because they can't be optimized further. + if (op.cycles == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds"); table.SetEntryName("ops"); table.SetShowCategoryTable(); - for (const auto& op : sorted_ops) { + table.SetShowAllEntries(); + for (const auto& op : op_infos_) { + // Skip ops with 0 optimal seconds and 0 actual cycles. As in + // print_op(), these are uninteresting because they're expected to be + // free, and they were actually free. + if (op.cycles == 0 && op.optimal_seconds == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fc24acd2713f4cd8af2816ffdf085e84a4920cbc..6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(computation_name.ToString()), + : computation_name_(std::string(computation_name)), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -41,15 +41,17 @@ class HumanReadableProfileBuilder { int64 total_cycles() const { return total_cycles_; } // Adds an operation to the profile. If you don't know the number of - // floating-point ops or bytes touched by the op, pass -1 for that param. + // floating-point ops or bytes touched by the op, or if you don't know how + // fast it would run optimally, pass -1 for that param. void AddOp(tensorflow::StringPiece op_name, tensorflow::StringPiece short_name, tensorflow::StringPiece category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back( - {op_name.ToString(), short_name.ToString(), category.ToString(), cycles, - flop_count, transcendental_count, bytes_accessed, optimal_seconds}); + op_infos_.push_back({std::string(op_name), std::string(short_name), + std::string(category), cycles, flop_count, + transcendental_count, bytes_accessed, + optimal_seconds}); } // Gets the human-readable profile. @@ -61,10 +63,10 @@ class HumanReadableProfileBuilder { string short_name; string category; int64 cycles; - int64 flop_count; + int64 flop_count; // -1 if unknown int64 transcendental_count; - int64 bytes_accessed; - float optimal_seconds; + int64 bytes_accessed; // -1 if unknown + float optimal_seconds; // -1 if unknown }; double CyclesToSeconds(int64 cycles) const { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..15b2d8f4990735c56f105e7c1b9b7dc70609d898 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,269 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? "scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source()), " ", + ToString(indexed_array->indices()), " ", indexed_array->source_dim(), + "->[", tensorflow::str_util::Join(indexed_array->output_dims(), ","), + "])"); + } + } +} + +Analysis::Array* IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TraverseAndPopulateCache(instr); + return FindOrDie(cache_, instr); +} + +void IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. + + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + InsertOrDie(&cache_, instr, ComputeArrayFor(instr)); + break; + } + } while (!stack.empty()); +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + switch (instr->opcode()) { + default: + computed_array = nullptr; + break; + case HloOpcode::kConstant: + computed_array = ComputeArrayForConstant(instr->literal()); + break; + case HloOpcode::kGather: + computed_array = ComputeArrayForGather( + instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1))); + break; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). + + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. + simulated_index.erase(simulated_index.begin() + source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. + simulated_index.erase(simulated_index.begin() + source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + if (!c_binary_search(dim_numbers.elided_window_dims(), + dim_numbers.gather_dims_to_operand_dims(0))) { + return nullptr; + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + auto* t = analysis.GetArrayFor(instr); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..b132a8f25153d2e86e8aa477fdb851f1c9c8e719 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,298 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operation that maps to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. + // + // An instance of ScalarIndexedArray represents a array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,19] and the output index [A,B,C,D,E] is mapped to the input + // index [B,D,indices[A,C],E]. + class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + int64 source_dim() const { return source_dim_; } + tensorflow::gtl::ArraySlice output_dims() const { + return output_dims_; + } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64 source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). + class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64 source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + Array* GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. + string ToString(Array* root); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + void TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. + Array* ComputeArrayFor(const HloInstruction* instr); + + Array* ComputeArrayForConstant(const Literal& literal); + + Array* ComputeArrayForGather(const Shape& shape, + const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, + Array* source, Array* indices); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happened to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + ScalarIndexedArray* FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape); + + template + T* Construct(Args&&... args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + std::vector> owned_tensors_; + std::vector> owned_literals_; + tensorflow::gtl::FlatMap cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. +class IndexedArrayAnalysisPrinterPass : public HloPassInterface { + public: + tensorflow::StringPiece name() const override; + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2731b7c51a45c4f9b713d99ef3e4623ad2c9c83 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -0,0 +1,191 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + string result = + indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + LOG(INFO) << result; + ASSERT_EQ(result, root_expression); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 1->[0,1,3])"); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7aa1c7c8358318d02a000d968a2672123400ad6e..d2af261008f40ee83e0676cfc7e67c45f8be1844 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({4, 3, 3, 4}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({3, 1, -1, -3}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index dc1a39e9fa9fd3ef5c55bd86309fe23f5ef51dd5..cb6c98c48171a06539499b723a8d8b7aa0ccc96a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,6 +28,25 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// These nodes can always be duplicated into consumers, even if +// InstructionFusion::may_duplicate_ is false. +// +// In general these should be nodes that get *cheaper* the more they're +// duplicated (and fused into consumers). +// +// TODO(jlebar): Duplicating instructions when we have a variable called "may +// duplicate" that's equal to false is not pretty. +bool IsAlwaysDuplicable(const HloInstruction& instruction) { + // We are always willing to duplicate a widening type-conversion instruction + // if it means we can fuse the convert into a consumer. This allows the + // consumer to read less memory, which is almost always a performance win. + return instruction.opcode() == HloOpcode::kConvert && + ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction.shape()); +} +} // namespace + /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -101,11 +120,13 @@ namespace xla { case HloOpcode::kDivide: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: @@ -393,12 +414,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { return changed; } -HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { +HloInstruction* InstructionFusion::AddFusionInstruction( + HloInstruction* producer, HloInstruction* consumer) { HloInstruction* fusion_instruction; - - VLOG(2) << "Fusing " << producer->ToString() << " into " - << consumer->ToString(); auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -410,17 +428,35 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } + return fusion_instruction; +} +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + VLOG(2) << "Fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); fusion_instruction->FuseInstruction(producer); return fusion_instruction; } +HloInstruction* InstructionFusion::FuseIntoMultiOutput( + HloInstruction* producer, HloInstruction* consumer) { + VLOG(2) << "Multi-output fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + return fusion_instruction; +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && - (is_expensive_(*producer) || !may_duplicate_)) { + (!may_duplicate_ || is_expensive_(*producer)) && + !IsAlwaysDuplicable(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 2ea1fcf937ceaf2cce3f8ed0891399384d93dbd0..c3c2ed0aaa81d6f346ec6e70d9c8b3b923e0a3d2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -70,6 +70,13 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // Creates a new fusion instruction containing `producer` and `consumer`. A + // tuple is added as the fusion instruction's root, which consumes from both, + // `producer` and `consumer`. This style of fusion is referred to as + // multi-output fusion. + virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer); + // An "effectively unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic @@ -95,6 +102,9 @@ class InstructionFusion : public HloPassInterface { // The set of producers whose consumers we cannot fuse into. using DoNotFuseSet = std::unordered_set; + HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer); + // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index e78b99a80cf41318faa1cb709428b8ba0f531944..df109df7877eefe4c337f93cc5a3a7a48e2e76c7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -21,8 +21,95 @@ limitations under the License. namespace xla { +namespace op = xla::testing::opcode_matchers; + using InstructionFusionTest = HloTestBase; +// Subclass of InstructionFusion exposing the protected methods Fuse and +// FuseIntoMultiOutput for testing. +class InstructionFusionForTesting : public InstructionFusion { + public: + explicit InstructionFusionForTesting(HloModule* module) + : InstructionFusion(InstructionFusion::IsExpensive) { + module_ = module; + computation_ = module->entry_computation(); + } + + HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::Fuse(producer, consumer); + } + + HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::FuseIntoMultiOutput(producer, consumer); + } +}; + +TEST_F(InstructionFusionTest, FuseInstructions) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT sub = f32[4,3]{1,0} subtract(add, p0) + })") + .ValueOrDie(); + HloInstruction* sub = module->entry_computation()->root_instruction(); + HloInstruction* add = sub->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(add, sub); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), + op::Subtract(op::Add(), op::Parameter())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { + auto module = tools::Parse(R"( + HloModule test_module + fused_computation { + p1 = f32[4,3] parameter(0) + add = f32[4,3] add(p1, p1) + } + ENTRY entry_computation { + p0 = f32[4,3] parameter(0) + abs = f32[4,3] abs(p0) + ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(abs, root); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + abs = f32[4,3]{1,0} abs(p0) + tanh = f32[4,3]{1,0} tanh(abs) + ROOT add = f32[4,3]{1,0} add(abs, tanh) + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* tanh = root->mutable_operand(1); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs())) + << module->ToString(); +} + TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( @@ -90,7 +177,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } // Counts the number of HLO ops with a given op code in the specified module. @@ -124,7 +212,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); // Make sure the add hasn't been duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { @@ -149,7 +237,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); // Make sure the add hasn't been duplicated. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); @@ -242,7 +334,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { .Run(module.get()) .ValueOrDie()) << module->ToString(); - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); // Make sure we didn't duplicate any adds. EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); @@ -291,4 +388,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, + WideningConvertsAreAlwaysDuplicableIntoConsumers) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY Test { + p0 = f16[100] parameter(0) + c = f32[100] convert(p0) + add = f32[100] add(c, c) + ROOT mul = f32[100] multiply(c, c) + })") + .ValueOrDie(); + + // The convert should be fused into the add and mul, even though may_duplicate + // is false, because it's always beneficial to fuse/duplicate widening + // converts into consumers. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion(op::Parameter())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 45505484951abfcee93a62fec7a99e86cbb9150c..524d3234eb4eff9c7d000eca1a0d9f5c4fae90af 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -117,6 +116,5 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index eecbbcb93df64b09acb5e009d3db79e51dab0c93..c1666530687f2f8407a9dcb4e271c9d95552a689 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" @@ -45,8 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->device_entry_computation_layout()); - + hlo_module->mutable_device_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -71,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -101,17 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { return xla::MakeUnique(); }); xla::ComputationPlacer::RegisterComputationPlacer( - se::interpreter::kXlaInterpreterPlatformId, &CreateComputationPlacer); + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 61f199bc9e8f4f95a2f097af4abf9395a1e05f64..029e71058a7373b9310c6d9ffdb65f72ca28e5af 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,16 +31,17 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -82,10 +82,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index b0b797ca7d6f449a11c662ffba7c2a0a0040e47e..91d8148d26dc8eddbafdaf4870d9efbb73a12816 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,13 +42,15 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, @@ -54,6 +58,11 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 92e069a8c67c1d441ba9d396dee503c9b3bde0df..42c2c28997d5f3b02f1fe4effca164c893e4071d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -31,13 +30,13 @@ limitations under the License. namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id XlaInterpreterPlatform::id() const { - return kXlaInterpreterPlatformId; -} +Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } @@ -106,8 +105,6 @@ REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); - // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index d68c5aa20dda7ac246ed4aa667851e385a604c04..0187f6d473b19f50136e214708e56f833627d9d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -28,7 +29,8 @@ namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(); + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -55,6 +57,8 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. + Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index cfa7ba5e81ddd003978a2bd763384581c55b5c83..7067b6f86a0fb24fb946ad236bca9bbd48d53722 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. - shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + shape_with_layout = ¶meter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts. - if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false computations + // might have negotiated a different layout. But for the conditional + // instruction POV the layout must match, so we run again on the false + // computation, this time with proper computation layout. + VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. + if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -760,6 +776,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. return Status::OK(); } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); return instruction->ReplaceOperandWith(operand_no, operand_copy); } @@ -896,32 +919,31 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } LayoutAssignment::LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " - << entry_computation_layout_.ToString(); + VLOG(1) << "Entry computation layout given to layout assignment: " + << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. for (const ShapeLayout& parameter_layout : - entry_computation_layout_.parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { CHECK(parameter_layout.LayoutIsSet()); } - // TODO(b/29118294): Choose a better layout if the result layout is not set. - CHECK(entry_computation_layout_.result_layout().LayoutIsSet()); } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1481,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1533,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. for (HloInstruction* instruction : computation->instructions()) { @@ -1553,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation( return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1561,52 +1662,45 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + TF_RETURN_IF_ERROR(Init()); + + // We do two passes. The first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non entry computations), and we register + // the ComputationLayout which are naturally flowing in DFS fashion to the + // parameters and root instruction. + // Walking in DFS mode though, means that we can end up with incorrect layouts + // when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the bosy, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the root of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up the eventually inconsistent ComputationLayout, which + // will be then made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; } - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. - computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); - } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1616,9 +1710,54 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples). + int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 9663a793fdd7d4968700707a1003319e89ea19a3..c287cca0c54ba1bb514bd8d243c137eca99b258f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -288,7 +289,7 @@ class LayoutAssignment : public HloPassInterface { // If channel_constraints is nullptr, no kSend or kRecvs must be contained // within any module passed to `Run`. explicit LayoutAssignment( - const ComputationLayout& entry_computation_layout, + ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} tensorflow::StringPiece name() const override { return "layout-assignment"; } @@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - // added, then propagated until all LogicalBuffers in the computation are - // constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction contraints. + // Layouts constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -402,7 +408,26 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); - const ComputationLayout& entry_computation_layout_; + // Computes the ComputationLayout of the given computation based of the + // layouts assigned to parameters and root instruction, and inserts it to the + // computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. + Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation, to the computation layout passed in to this API. + // This API propagates missing layout, and also checks that the caller + // specified have been respected, by comparing those with the parameters and + // root computation instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + + ComputationLayout* entry_computation_layout_; protected: // Sets up the copy instruction according to the characteristic (sharding, @@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - static StatusOr CreateCopyWithNewLayout( + StatusOr CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. + void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless such operand is + // already a copy, and has a single user (which is forcibly the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + ChannelLayoutConstraints* channel_layout_constraints_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 7e1bb11eaada0e62b82c50903c9848f0a3a8307b..7508013199a82267efc0e1426cb5989d5fe844a0 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(*entry_computation_layout); + LayoutAssignment layout_assignment(entry_computation_layout); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); AssignLayouts(module.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to @@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { public: explicit OperandsMustBeTheSameLayoutAssignment( ComputationLayout* entry_computation_layout) - : LayoutAssignment(*entry_computation_layout) {} + : LayoutAssignment(entry_computation_layout) {} protected: Status PropagateBufferConstraint( @@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - EXPECT_EQ( - ::tensorflow::Status::OK(), - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr) - .status()); + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); } // A GTE inside of a fusion node inherits the layout of its operand (which @@ -808,7 +807,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(computation_layout); + LayoutAssignment layout_assignment(&computation_layout); Status error_status = layout_assignment.Run(module.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc deleted file mode 100644 index 68c99256a246edcf43a8358f667fc4458b9b4fea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ /dev/null @@ -1,379 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. - // Return false if any uses are detected at 'index', returns true otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return false; - } - } - // Return true: found no uses of 'operand' at 'index' in 'user'. - return true; - } - return false; -} - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : - dataflow.GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; - } - } - } - } - - return true; -} - -namespace { - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. -// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'instruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector> uses; - const PointsToSet::BufferList& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if there is exactly one use of 'operand' at 'operand_index' -// in 'fusion.fused_instructions', where the singleton use is the fused -// root at operand index 'use_operand_index'. Returns false otherwise. -// -// REQUIRES: 'fusion' opcode is a kFusion instruction. -bool HasUniqueFusedUseOfOperandAt( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* fusion, const int64 use_operand_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Check that 'operand' is unique in the operand list of 'fusion'. - if (fusion->OperandIndices(operand).size() > 1) { - return false; - } - // Find fusion parameter associated with 'operand'. - const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { - return fusion->operand(fused_param->parameter_number()) == operand; - }); - if (fused_param_it == fused_params.end()) { - return false; - } - auto* fused_param = *fused_param_it; - // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root (at index in 'use_operand_indices'). - return fused_param_uses.size() == 1 && - fused_param_uses[0].first == fusion->fused_expression_root() && - fused_param_uses[0].second == use_operand_index; -} - -} // namespace - -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following qualifications: -// -// (1) Is element-wise. Or... -// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index -// 0. -// -// (2) and (3) can only be determined if points-to analysis is available. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, - other_add_operand_index, - points_to_analysis); - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // TODO(b/62548313): Remove when buffer assignment is module scoped and - // does not assign buffers to calls. - // Find called computation parameter associated with 'operand'. - const std::vector operand_indices = user->OperandIndices(operand); - if (operand_indices.size() > 1) { - return false; - } - CHECK_EQ(1, operand_indices.size()); - auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); - // Get all uses of 'operand' at 'index' in called computation. - auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, - points_to_analysis); - - // Return true iff: - // *) There exists exactly one use of 'operand' in called computation. - // *) The unique use is by the root instruction of called computation. - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - auto* callee_root = user->to_apply()->root_instruction(); - return param_uses.size() == 1 && param_uses[0].first == callee_root && - callee_root->IsElementwiseOnOperand(param_uses[0].second); - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& value = - dataflow.GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; - } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // Get all uses of value defined by 'operand' at 'operand_index'. - const auto& uses = - dataflow.GetValueDefinedAt(operand, operand_index).uses(); - // Return true iff: - // *) There exists two uses of 'operand'. - // *) One use is by 'user' (caller). - // *) One use is by root instruction of called computation (callee root). - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { - return use.instruction == user; - }) != uses.end(); - auto* callee_root = user->to_apply()->root_instruction(); - const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); - return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h deleted file mode 100644 index 28ef991880039de73cc158a67ef2a5f78fc90e6d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A collection of utilities on the HLO graph. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ - -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// Returns true if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis); -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow); - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc deleted file mode 100644 index f8b309488eeb5391b1cad5db760934ec1f7e3521..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ /dev/null @@ -1,463 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -class PointsToAnalysisTestBase : public HloTestBase { - protected: - void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); - computation_ = module_->AddEntryComputation(std::move(computation)); - } - - void RunAnalysis() { - CHECK_NOTNULL(module_.get()); - points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); - } - - void BuildModuleAndRunAnalysis(std::unique_ptr computation) { - BuildModule(std::move(computation)); - RunAnalysis(); - } - - std::unique_ptr module_; - HloComputation* computation_ = nullptr; - std::unique_ptr points_to_analysis_; - std::unique_ptr dataflow_analysis_; -}; - -class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; - -TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { - auto builder = HloComputation::Builder(TestName()); - - Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); -} - -TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction never uses tuple element 0, but does use element 1. - EXPECT_TRUE( - DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); -} - -class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto log = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape in_shape = ShapeUtil::MakeShape(F32, {8}); - Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, in_shape, "param0")); - auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *dataflow_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction can share with tuple element 1. - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *dataflow_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto update = builder.AddInstruction( - HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); - auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, dot}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused dot add should be able to share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - auto b_t = builder.AddInstruction( - HloInstruction::CreateTranspose(data_shape, b, {1, 0})); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - - auto nested_fusion = computation_->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); - - auto fusion = computation_->CreateFusionInstruction( - {add, nested_fusion}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused transpose-dot-add should be share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto reverse = builder.AddInstruction( - HloInstruction::CreateReverse(data_shape, operand, {0, 1})); - - auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, two, reverse}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused operand->reverse->add cannot alias operand buffer 'operand'. - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - - auto make_cond = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Cond"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); - return builder.Build(); - }; - - auto make_body = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Body"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); - return builder.Build(); - }; - - module_ = CreateNewModule(); - HloComputation* cond_computation = - module_->AddEmbeddedComputation(make_cond()); - HloComputation* body_computation = - module_->AddEmbeddedComputation(make_body()); - - auto builder = HloComputation::Builder(TestName()); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto whil = builder.AddInstruction(HloInstruction::CreateWhile( - data_shape, cond_computation, body_computation, data)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - // The While instruction can share with the data operand. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); -} - -// Tests that Call can alias operand buffer if the only use of the operand -// in the called computation is an elementwise instruction. -TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { - Shape shape = ShapeUtil::MakeShape(F32, {8}); - // Build sub-computation with fusion root. - auto sub_builder = HloComputation::Builder(TestName() + "_sub"); - auto sub_param = sub_builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "sub_param")); - auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); - auto add = sub_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - - module_ = CreateNewModule(); - auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); - sub_computation->CreateFusionInstruction({add, ones}, - HloInstruction::FusionKind::kLoop); - - // Build entry-computation with kCall which calls 'sub_computation'. - auto builder = HloComputation::Builder(TestName()); - - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto reverse = - builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); - auto call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {reverse}, sub_computation)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *dataflow_analysis_)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1880b010d57e83aa6e9ffa95fda299e1a0..f172b1d87c870270436f7301ed200b47d08431a7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -151,7 +151,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 3312a888443233139841ce7a5e3173f907605e1d..7323abeb2077154f82828bcda3e90eb45a67138a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress( } CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); - std::vector actual_index; - bool is_implicit_broadcast = false; - // We perform broadcasting when the operand shape has dimension(s) of size - // 1. In this case we fix the index value for that dimension to zero. This - // effectively broadcasts along this dimension. - for (int64 i = 0; i < index.size(); ++i) { - auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); - is_implicit_broadcast |= dim == 1; - } - - if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( @@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( {index.linear()}, llvm_ir::AsStringRef(name)); } + std::vector actual_index; + for (int64 i = 0; i < index.size(); ++i) { + // When dimension i is of size 1, LLVM optimization is able to replace + // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to + // produce better code in some cases. + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + } + // "base_ptr_" has the type of "*" // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element // should be computed by diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c..64b935bbf1fb9033cd2e1259b4639cd3780be711 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -100,6 +100,15 @@ class KernelSupportLibrary { [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); } + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + } + void For( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 3978acc132f34b8b195d3772ccf71d0d467984db..0728ccfff7b85e3751f33bc5272a5f22d4e5411a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -124,7 +123,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); @@ -135,7 +134,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 9ff497aecd0bc964c929205c7fd410cca87d9b77..b70d28ecd3033eb26629718e50ce48f39b162273 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -72,7 +71,7 @@ class LoopEmitter { tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b7400464e4f4f97d301f35ed3b7b083bca1..dacc54742c0897bbd92315f1e33a484aae56bb7f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // output_index[dim] = start_index[dim] + update_index[dim] // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f74bcb0b79355c8e69890487266cbc5f2a4500be..3a6a7c25f4b727c7112dbcbcb4f3d892679a0011 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -53,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : prefix.ToString()); + string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..c115bc097f3b1dd810654745b835a977955718c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" + +namespace xla { + +void OwningDeviceMemory::Free() { + CHECK(allocator_ != nullptr) + << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + auto status = allocator_->Deallocate(device_ordinal_, mem_); + if (!status.ok()) { + LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; + } + + allocator_ = nullptr; + mem_ = se::DeviceMemoryBase(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf071f0d9d09dfbf74b15e73caaf542714ec8d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.h @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Break circular dependency between this file and device_memory_allocator.h. +class DeviceMemoryAllocator; + +// Owning pointer for memory on a device. +// +// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can +// point to memory that resides on a "device" (e.g. a GPU). When an +// OwningDeviceMemory goes out of scope, it frees the memory it owns. +// +// We say that an instance of OwningDeviceMemory is "active" if it currently +// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, +// Free()'ing, and other actions can deactive an active object. +// +// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of +// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a +// StreamExecutor. This class needs to free via a xla::DeviceMemoryAllocator. +class OwningDeviceMemory { + public: + OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} + + explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, + DeviceMemoryAllocator* allocator) + : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { + CHECK(allocator != nullptr) << "allocator cannot be null."; + } + + OwningDeviceMemory(OwningDeviceMemory&& other) + : mem_(other.mem_), + device_ordinal_(other.device_ordinal_), + allocator_(other.allocator_) { + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + } + + OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { + if (allocator_ != nullptr) { + Free(); + } + mem_ = other.mem_; + device_ordinal_ = other.device_ordinal_; + allocator_ = other.allocator_; + + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + return *this; + } + + // Deactivates this instance if it's active. Nop if it's not active. + OwningDeviceMemory& operator=(std::nullptr_t) { + if (allocator_ != nullptr) { + Free(); + } + return *this; + } + + ~OwningDeviceMemory() { + if (allocator_ != nullptr) { + Free(); + } + } + + // The returned allocator is nonnull iff this object is active. + DeviceMemoryAllocator* allocator() const { return allocator_; } + + int device_ordinal() const { return device_ordinal_; } + + // Gets the device memory pointer. + const void* opaque() const { return mem_.opaque(); } + void* opaque() { return mem_.opaque(); } + + uint64 size() const { return mem_.size(); } + + // Determines whether this wraps a null pointer. + // + // !is_null() is sufficient but not necessary to imply `this` is active. + bool is_null() const { return mem_.is_null(); } + + se::DeviceMemoryBase AsDeviceMemoryBase() { + return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); + } + + // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates + // this object. Precondition: `this` is active. + TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { + CHECK(allocator_ != nullptr) + << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + allocator_ = nullptr; + se::DeviceMemoryBase mem(mem_); + mem_ = se::DeviceMemoryBase(); + return mem; + } + + // Frees the wrapped DeviceMemoryBase and deactivates this object. + // Precondition: `this` is active. + void Free(); + + private: + se::DeviceMemoryBase mem_; + int device_ordinal_; + DeviceMemoryAllocator* allocator_; // Null if this object is inactive. +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 586f6ef7a9c4f17f69340e77be17aec2f677a791..d3bc47e61e0e75fa2ef181988700f88cec9c1d76 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl { HloInstructionPattern operand_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. +template +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -807,6 +831,16 @@ class HloInstructionPattern { matched_inst_); } + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. + constexpr HloInstructionPattern> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern>( + HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index c88157c312524fb273e6df368d2ef61d679d1d8b..204e8c99209fa95adb868a676bb9e5144fed432c 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -170,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) { Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 495f8801ba82ecbcf9f6e5db5507ef8785c752d6..cb0f76ebe4d445059fdf37ebf559bef851a57104 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -64,7 +64,7 @@ namespace { // Records the arguments used to invoke a computation in a SessionModule // proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { @@ -75,24 +75,22 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } // Records the arguments used to invoke a computation in an HloSnapshot proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, HloSnapshot* module) { @@ -103,20 +101,18 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace @@ -199,8 +195,8 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { +Status Service::Computation(const ComputationRequest* arg, + ComputationResponse* result) { if (arg->name().empty()) { return InvalidArgument("computation request needs a name"); } @@ -210,24 +206,23 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, VLOG(1) << Printf("Created new computation %s on service %p, name %s", result->computation().ShortDebugString().c_str(), this, arg->name().c_str()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. -tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -235,11 +230,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -345,6 +340,9 @@ StatusOr> Service::CreateModuleConfig( // If the result layout is not set, then choose the default. // TODO(b/29118294): Allow the compiler to choose a better layout in this // case. + // TODO(b/78356948): We are forcing the default layout here. We should fix + // clients which expect a default layout, to be explicit about it, by + // passing the proper ExecutionOptions with shape_with_output_layout set. host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); } @@ -511,7 +509,7 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) { module->device_entry_computation_layout().result_shape(), execute_backend_->transfer_manager()->HostShapeToDeviceShape( module->host_entry_computation_layout().result_shape()))); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> Service::BuildExecutable( @@ -801,8 +799,8 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { +Status Service::SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); return computation->SetReturnValue(arg->operand()); @@ -849,8 +847,8 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { +Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector>> all_arguments; @@ -957,11 +955,11 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, } VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -1058,11 +1056,11 @@ tensorflow::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -1082,11 +1080,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { ExecuteParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1094,8 +1092,8 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1103,7 +1101,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::PickParallelResponse( +Status Service::PickParallelResponse( const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. We have b/76035356 for restructuring the client API to clean @@ -1126,8 +1124,7 @@ tensorflow::Status Service::PickParallelResponse( return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1198,7 +1195,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, } VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> Service::BuildExecutable( @@ -1243,8 +1240,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { VLOG(1) << "running execute-graph request"; if (!arg->has_computation()) { @@ -1303,11 +1300,11 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, } VLOG(1) << "successfully completed 'execute-graph' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { +Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1383,11 +1380,11 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.clear(); VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1398,11 +1395,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); @@ -1432,7 +1429,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1450,8 +1447,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1484,11 +1481,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1517,9 +1514,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1545,16 +1541,16 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { +Status Service::IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1570,11 +1566,11 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { +Status Service::ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1661,11 +1657,11 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -1703,20 +1699,18 @@ tensorflow::Status Service::ComputeConstantGraph( } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { +Status Service::GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); @@ -1726,21 +1720,21 @@ tensorflow::Status Service::GetComputationShape( TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( versioned_handle.version)); *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { +Status Service::GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_shape(), computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { +Status Service::GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1766,10 +1760,10 @@ tensorflow::Status Service::GetComputationStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationGraphStats( +Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); @@ -1796,11 +1790,11 @@ tensorflow::Status Service::GetComputationGraphStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } template -tensorflow::Status Service::AddInstruction( +Status Service::AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder) { @@ -1808,10 +1802,10 @@ tensorflow::Status Service::AddInstruction( computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { +Status Service::Op(const OpRequest* arg, OpResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); StatusOr handle_status; @@ -2033,27 +2027,26 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { if (arg->has_sharding()) { TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { +Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, computation_tracker_.SnapshotComputation(arg->computation())); result->set_allocated_module(module.release()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::LoadComputationSnapshot( +Status Service::LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) { TF_ASSIGN_OR_RETURN(*result->mutable_computation(), computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index f84fe407e05da371da66ba33efd6e8165198cf2c..81fbd41957887aec763e1cfe165ad0d1d2ac2269 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -85,55 +85,52 @@ class Service : public ServiceInterface { // Creates a new computation with the given name. // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Modifies the provided computation so that subsequent executions // will compute the provided ComputationDataHandle, rather than the // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; // Executes a computation with the provided global data passed as // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -143,9 +140,8 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Asynchronously executes a computation with provided arguments. Invokes // the provided computation with the provided global data passed as @@ -154,38 +150,33 @@ class Service : public ServiceInterface { // (Note: The corresponding function in xla::Client was removed as part of // b/64116060, in an attempt to simplify our API. We're keeping this around // for now in case we want to expose this to clients in a different way.) - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed othe device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -196,71 +187,65 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; // Tests if an expression is a compile-time constant. - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Returns the program shape of the computation associated with the given // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; ///// // Computation-oriented methods. // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; // Retrieves the statistics of a computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Snapshots the current state of a computation handle into a serializable // protocol buffer form, so it can be loaded via // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) override; // Loads a computation from a serialized protocol buffer created via // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the ComputationTracker of the current service instance. // Only used in unit tests to access user computations from client. @@ -389,7 +374,7 @@ class Service : public ServiceInterface { // Convenience function for adding a function to a user computation. template - tensorflow::Status AddInstruction( + Status AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder); @@ -397,16 +382,14 @@ class Service : public ServiceInterface { // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result); - tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 48b2922e77b78719e5d3469cbaa4fc15969de91b..3500978bdd808f0c7684d14a05636d90105aa594 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -58,6 +58,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_COS; case HloOpcode::kExp: return UNOP_EXP; + case HloOpcode::kExpm1: + return UNOP_EXPM1; case HloOpcode::kFloor: return UNOP_FLOOR; case HloOpcode::kImag: @@ -66,6 +68,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; + case HloOpcode::kLog1p: + return UNOP_LOG1P; case HloOpcode::kNot: return UNOP_NOT; case HloOpcode::kNegate: @@ -168,24 +172,24 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { return InvalidArgument("Expected non-opaque argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -245,7 +249,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -337,7 +341,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_COS: case UNOP_SIN: case UNOP_EXP: + case UNOP_EXPM1: case UNOP_LOG: + case UNOP_LOG1P: case UNOP_TANH: if (!ShapeUtil::ElementIsFloating(arg) && !ShapeUtil::ElementIsComplex(arg)) { @@ -1212,11 +1218,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( @@ -1318,15 +1324,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index fb3b5f06dad67b4305aed0305c9f6441e666db53..7d7dcac10b65933d1c81b8aca77465932694bfdb 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include #include #include @@ -25,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -123,6 +123,8 @@ ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) } ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + Deallocate(); + *static_cast(this) = std::move(static_cast(s)); allocator_ = s.allocator_; // Null out s.allocator_ so it doesn't try to free anything in its destructor. @@ -130,7 +132,15 @@ ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { return *this; } -ScopedShapedBuffer::~ScopedShapedBuffer() { +ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); } + +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); + return shaped_buffer; +} + +void ScopedShapedBuffer::Deallocate() { // allocator_ will be null if we were moved-from. if (allocator_ == nullptr) { return; @@ -138,22 +148,14 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - std::set deallocated_opaques; + tensorflow::gtl::FlatSet deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && - deallocated_opaques.count(memory_base.opaque()) == 0) { - deallocated_opaques.insert(memory_base.opaque()); - TF_CHECK_OK( - this->allocator_->Deallocate(this->device_ordinal(), &memory_base)); + deallocated_ptrs.insert(memory_base.opaque()).second) { + TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base)); } } } -ShapedBuffer ScopedShapedBuffer::release() { - ShapedBuffer shaped_buffer(static_cast(*this)); - buffers_ = ShapeTree(); - return shaped_buffer; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e10fca9e9466c018f6cb4da2f5618e4db4977307..905a7e82e621f2bf4588b71be5dbab20f892cafe 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -148,13 +148,29 @@ class ScopedShapedBuffer : public ShapedBuffer { // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Releases all device memory owned by this ScopedShapedBuffer and returns the - // device memory pointers in the form of a ShapedBuffer. The returned - // ShapedBuffer takes over the memory from the ScopedShapedBuffer. The - // resulting ScopedShapedBuffer can only be destroyed. - ShapedBuffer release(); + // Sets the device memory buffer at the given index. + // + // If the given buffer's device memory is non-null, its device_ordinal and + // allocator must match those in `this`. + void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + if (!buffer.is_null()) { + CHECK_EQ(buffer.device_ordinal(), device_ordinal()); + CHECK_EQ(buffer.allocator(), allocator_); + *buffers_.mutable_element(index) = buffer.Forget(); + } else { + *buffers_.mutable_element(index) = se::DeviceMemoryBase(); + } + } + + // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from + // this ScopedShapedBuffer, without freeing any of the associated memory. + // + // It's the caller's job to ensure that the memory contained therein is freed. + TF_MUST_USE_RESULT ShapedBuffer release(); protected: + void Deallocate(); + DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fc243667911651c788e3c1e5f1d39d86170f1ad --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace { + +TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { + TF_ASSERT_OK_AND_ASSIGN(auto platforms, + xla::PlatformUtil::GetSupportedPlatforms()); + ASSERT_FALSE(platforms.empty()); + auto* platform = platforms[0]; + TF_ASSERT_OK_AND_ASSIGN(auto executors, + xla::PlatformUtil::GetStreamExecutors(platform)); + xla::StreamExecutorMemoryAllocator allocator(platform, executors); + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + const int kDeviceOrdinal = 0; + auto scoped_buffer = tensorflow::MakeUnique( + shape, shape, &allocator, kDeviceOrdinal); + std::unique_ptr buffer = std::move(scoped_buffer); + buffer = nullptr; +} + +class TestAllocator : public DeviceMemoryAllocator { + public: + TestAllocator() + : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { + } + + ~TestAllocator() override { + if (!allocations_.empty()) { + ADD_FAILURE() << "Some allocations not freed!"; + } + } + + // Pull in two-arg overload of Allocate. + using DeviceMemoryAllocator::Allocate; + + StatusOr Allocate(int device_ordinal, uint64 size, + bool /*retry_on_failure*/) override { + // By contract, we must return null if size == 0. + if (size == 0) { + return OwningDeviceMemory(); + } + void* buf = malloc(size); + allocations_.insert({device_ordinal, buf}); + return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, + this); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + if (mem.is_null()) { + return Status::OK(); + } + + auto it = allocations_.find({device_ordinal, mem.opaque()}); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + free(mem.opaque()); + allocations_.erase(it); + } + return Status::OK(); + } + + bool AllowsAsynchronousDeallocation() const override { return false; } + + private: + std::set> allocations_; +}; + +TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { + Shape s = ShapeUtil::MakeShape(F32, {1}); + TestAllocator allocator; + ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + sb1.set_buffer( + allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), + /*index=*/{}); + + ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + sb2.set_buffer( + allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), + /*index=*/{}); + + sb1 = std::move(sb2); + + // TestAllocator's destructor checks that all memory was freed. +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 8b71a415091f028b3167cddb2583754e72ba17c8..c4d01562c4e32225ebb984d8fcd93ec3fa86e403 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -37,7 +37,7 @@ TransferManager::GetPlatformTransferManagers() { } Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -196,9 +196,11 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const ShapeIndex& index = pair.first; se::DeviceMemoryBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); - TF_ASSIGN_OR_RETURN(memory_base, + TF_ASSIGN_OR_RETURN(auto memory, allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); + // Move the allocated buffer into the ScopedShapedBuffer, which owns it. + memory_base = memory.Forget(); } return std::move(shaped_buffer); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index d82b4f0f81b5da38c1caf80bddefa0d3f7842463..43a8092b06fba0e2495bce0ee1a309c85a908273 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,14 +65,14 @@ class TransferManager { // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // but need not have the same layout virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. Status TransferArrayToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( se::StreamExecutor* executor, const Shape& shape, @@ -81,7 +81,7 @@ class TransferManager { // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) = 0; + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 3efd38ce0daa3e3f3398b32463019df6cd10a009..ba16dc640e2d2974eab4fc8b134a6e33c03e3b85 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode()) { + if (HloOpcode::kDot != dot.opcode() || + dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { return {}; } @@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); + } else if (ShapeUtil::Rank(operand.shape()) != 2) { + return {}; } } @@ -74,23 +77,39 @@ using InstructionOperandsPair = // Folds the operands of `dot` that are foldable transposes. `computation` is // the parent HLO computation of `dot`. -// -// Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair) { - auto* dot = pair.first; - std::vector instructions_to_fuse(1, dot); - for (const int64 operand_index : pair.second) { - instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); - } - - // Early-exit if no operands are foldable. - if (instructions_to_fuse.size() == 1) { - return false; +Status FoldTransposeIntoDot(InstructionOperandsPair pair) { + HloInstruction* dot = pair.first; + + DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* new_lhs = dot->mutable_operand(0); + HloInstruction* new_rhs = dot->mutable_operand(1); + + CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); + + for (int64 operand_index : pair.second) { + // We've checked that there aren't any batch dimensions and that the inputs + // are rank 2, and shape inference guarantees that there is exactly one + // contracting dimension. + if (operand_index == 0) { + CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_lhs_contracting_dimensions( + 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + new_lhs = new_lhs->mutable_operand(0); + } else { + CHECK_EQ(operand_index, 1); + CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + new_rhs = new_rhs->mutable_operand(0); + } } - dot->parent()->CreateFusionInstruction( - instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); - return true; + std::unique_ptr new_dot = HloInstruction::CreateDot( + dot->shape(), new_lhs, new_rhs, new_dim_numbers); + return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } // Folds the operands of `convolution` that are foldable transposes. @@ -196,7 +215,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { @@ -205,7 +224,8 @@ StatusOr TransposeFolding::Run(HloModule* module) { bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair); + TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { changed |= FoldTransposeIntoConvolution(pair); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 0319109f7fc54c6abfe7627bcaff747bade90f41..f73f1227aaf1630a9e7c43bb508732c5518ef929 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -31,9 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase { }; TEST_F(TransposeFoldingTest, FoldDotTranspose) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3]{1,0} parameter(0) + y = f32[2,3]{1,0} parameter(1) + transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0} + ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); FoldTranspose(module.get()); - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.size()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = *instruction_set.begin(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) { + string hlo_string = R"( +HloModule FoldDotTranspose - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +ENTRY entry_computation { + x = f32[2,3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { - auto builder = HloComputation::Builder("entry_computation"); - // 2x1 - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); - // 3x2 - HloInstruction* const1 = - builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); - HloInstruction* transpose0 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1, 3}), - /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTransposeConstant + +ENTRY entry_computation { + constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} + constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} + ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); FoldTranspose(module.get()); - for (auto* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK_EQ(2, instruction->operand_count()); - EXPECT_EQ(const0, instruction->operand(0)); - EXPECT_EQ(const1, instruction->operand(1)); - } - } - - // The created fusion instruction should contain two parameters, two - // transposes (one for each parameter) and one dot. - EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Constant(), op::Constant(), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1)); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instruction_count()); } -TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); +TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) { + string hlo_string = R"( +HloModule FoldDotTransposeInCall - HloInstruction* call = module->OutlineExpressionFromComputation( - {transpose_y, dot}, "outlined", entry_computation); +callee { + name.0 = f32[2,3]{1,0} parameter(0) + name.1 = f32[2,3]{1,0} parameter(1) + transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0} + ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +ENTRY entry_computation { + y = f32[2,3]{1,0} parameter(1) + x = f32[2,3]{1,0} parameter(0) + ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); FoldTranspose(module.get()); - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(call)) - << "call is not in entry_computation."; - CHECK(instruction_set.empty()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = - call->called_computations().front()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); - - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); + const HloComputation* callee = module->GetComputationWithName("callee"); + ASSERT_NE(callee, nullptr); + EXPECT_THAT(callee->root_instruction(), + op::Dot(op::Parameter(1), op::Parameter(0), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } // Test that a two dimension swap of the kernel gets folded into convolution. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 657a8fe09ae9df906d695f7f49df72500d611792..8cb654493ca82dc702b2c1e7a4284f4f31d1e5f9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -588,4 +588,201 @@ void TuplePointsToAnalysis::InstructionToString( }); } +bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Return false if any uses are detected at 'index', returns true otherwise. + const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. +// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'. +std::vector> +TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const { + std::vector> uses; + const PointsToSet::BufferList& points_to = + GetPointsToSet(instruction).element(index); + for (const LogicalBuffer* buffer : points_to) { + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { + uses.emplace_back(alias_user, op_idx); + } + } + } + } + return uses; +} + +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index) const { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = + GetAllUsesOfInstructionAtIndex(fused_param, operand_index); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + +// User and operand can share buffers iff both instructions emit the same shape +// and layout, and 'user' meets one of the following qualifications: +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd output fusion instruction where the only use of +// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused +// root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. +bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c3743b150168ebcf1051050dc511e50c43108c4f..1ac713013650d807b15e33565e6d2dec406a5d13 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -256,6 +256,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + private: explicit TuplePointsToAnalysis( const HloModule* module, @@ -310,6 +327,13 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { return &per_instruction_[id]; } + std::vector> GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const; + bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* fusion, + const int64 use_operand_index) const; + // The module this analysis is performed on. const HloModule* module_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dec446d4dac650ba43992f7870764eedc80cb2cf..f558316b05b168a6f100e8ef69adfd9dbc023102 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -805,5 +805,348 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { Run(/*add_additional_gte0_user=*/true); } +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE( + points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + add_operand, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {}, + call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f73a2b0c783103d7f2da9534bc97c3..d668855084a884518b338cdf396a9330b9f43a2b 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -69,6 +69,7 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; + HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -78,11 +79,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - + if (first_gte == nullptr) { + first_gte = operand; + } else if (!first_gte->has_compatible_sharding(operand)) { + can_simplify = false; + break; + } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape())) { + instruction->shape()) || + !instruction->has_compatible_sharding(top_tuple)) { can_simplify = false; break; } @@ -108,15 +115,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); + if (instruction->has_compatible_sharding(element_source)) { + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); + } } } } diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 0f16a592b68e20f5dbd1e4655ad5720ecce5a7bd..9e62d0acfb98946f1e693fc0310098b4ec99750b 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -55,6 +55,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; + case UNOP_EXPM1: + return HloOpcode::kExpm1; case UNOP_FLOOR: return HloOpcode::kFloor; case UNOP_IMAG: @@ -63,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; + case UNOP_LOG1P: + return HloOpcode::kLog1p; case UNOP_NOT: return HloOpcode::kNot; case UNOP_NEGATE: diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 5b44c26b7c7b082556d9533cf3b3b1b98e5e4b09..141347a792c23a2c542d7b564ab76c118409865d 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,8 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -31,99 +32,93 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. - virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( + virtual Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; + virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) = 0; - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; + virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationGraphStats( + virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; + virtual Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) = 0; - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; // Methods used by ComputationBuilder. - virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; + virtual Status Computation(const ComputationRequest* arg, + ComputationResponse* result) = 0; - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; + virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; + virtual Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) = 0; - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; + virtual Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) = 0; - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; + virtual Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5780d37e1fd4d80ec881855951c8bba0eb..7ee366b27a82bdbcb7a63a57ea80194db8ca7df4 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index a1dce758cd3ab3f204ce330fca2a7d2bdf57a2be..36806da599cc9b27286e67c128bb7f496f29c105 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible @@ -49,7 +49,7 @@ class ShapeLayout { // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. - tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index ffaa40c2d673a2365342371ed8dab59565d1d08f..37c94ac543b166c14affd8165d244440ae6b67d6 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -42,36 +42,20 @@ namespace internal { template struct ShapeTreeNode { // Data corresponding to this node. - T data; + std::pair data; - // Children of this node. - std::vector> children; + // Children of this node, as indices into the container's nodes_ array. + std::vector children; - ShapeTreeNode() = default; - explicit ShapeTreeNode(const T& data) : data(data) {} - - ShapeTreeNode(const ShapeTreeNode& other) - : data(other.data), children(other.children.size()) { - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - - ShapeTreeNode& operator=(const ShapeTreeNode& other) { - if (this != &other) { - data = other.data; - children.resize(other.children.size()); - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - return *this; - } + explicit ShapeTreeNode(ShapeIndex index) + : ShapeTreeNode(std::move(index), T()) {} + ShapeTreeNode(ShapeIndex index, T data) + : data(std::move(index), std::move(data)) {} }; } // namespace internal -template +template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a @@ -95,10 +79,9 @@ class ShapeTreeIterator; // before its ShapeTree goes away. template class ShapeTree { - friend class ShapeTreeIterator; - friend class ShapeTreeIterator; - public: + using Node = internal::ShapeTreeNode; + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -110,30 +93,12 @@ class ShapeTree { // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); + explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); - - ShapeTree(const ShapeTree& other) { *this = other; } - ShapeTree(ShapeTree&&) = default; - - ShapeTree& operator=(const ShapeTree& other) { - root_ = other.root_; - - // Fix up internal pointer if necessary. - if (other.shape_storage_) { - CHECK_EQ(other.shape_, other.shape_storage_.get()); - shape_storage_.reset(new Shape(*other.shape_)); - shape_ = shape_storage_.get(); - } else { - shape_ = other.shape_; - } - - return *this; - } - - ShapeTree& operator=(ShapeTree&& other) = default; + ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -161,63 +126,70 @@ class ShapeTree { return Lookup(index)->children.empty(); } - // iterator implements a forward_iterator with value_type = - // std::pair - using iterator = ShapeTreeIterator; - using const_iterator = ShapeTreeIterator; + ShapeTree(const ShapeTree&) = default; + ShapeTree& operator=(const ShapeTree&) = default; + ShapeTree(ShapeTree&&) = default; + ShapeTree& operator=(ShapeTree&& other) = default; + + // iterator implements a bidirectional_iterator with + // value_type = std::pair. + // + // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. + using iterator = + ShapeTreeIterator, typename std::vector::iterator, + std::pair>; + using const_iterator = + ShapeTreeIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } iterator end() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. - iterator rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - iterator rend() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); } - const_iterator rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - const_iterator rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -227,20 +199,27 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - iterator leaf_rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } + reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } + const_reverse_iterator leaf_rbegin() const { + return const_reverse_iterator(leaf_end()); } - iterator leaf_rend() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_reverse_iterator leaf_rend() const { + return const_reverse_iterator(leaf_begin()); } - const_iterator leaf_rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/true); + + // Returns an iterator pointing to the given ShapeIndex. + // REQUIRES: index must exist in the ShapeTree. + iterator find(const ShapeIndex& index) { + Node* element = Lookup(index); + return iterator(&nodes_, typename std::vector::iterator(element), + /*iterate_leaves_only=*/false); } - const_iterator leaf_rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_iterator find(const ShapeIndex& index) const { + Node* element = Lookup(index); + return iterator(&nodes_, + typename std::vector::const_iterator(element), + /*iterate_leaves_only=*/false); } // Recursively traverses the shape and calls the given function at each @@ -282,8 +261,6 @@ class ShapeTree { bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: - using Node = internal::ShapeTreeNode; - // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); @@ -292,136 +269,57 @@ class ShapeTree { // default-constructed data values. void InitChildren(const Shape& shape, Node* node); + // Returns the number of subshapes, including interior nodes, in shape. + int64 CountSubshapes(const Shape& shape); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template - static Status ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index); + static Status ForEachHelper(const Fn& func, const std::vector& nodes); template - static Status ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index); + static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. Node* Lookup(const ShapeIndex& index); const Node* Lookup(const ShapeIndex& index) const; - // The root node, which contains all other nodes. - Node root_; + // The nodes in this shape tree. + std::vector nodes_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. - std::unique_ptr shape_storage_; + std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; -// Internal iterator that performs a pre-order walk. This is copyable, but -// contains a vector so isn't cheap to copy. This also means post-increment is -// expensive. The iterator value_type is equivalent to a std::pair, similar to std::map. The non-const iterator's T& type can be mutated -// in-place. -template -class ShapeTreeIterator : public std::iterator> { +// Internal iterator that performs a pre-order walk. This is cheap to copy. +// The iterator value_type is equivalent to a +// std::pair&, similar to std::map. +template +class ShapeTreeIterator + : public std::iterator { public: - using value_type = - typename std::conditional, - std::pair>::type; - using NodeType = - typename std::conditional::Node, - typename ShapeTree::Node>::type; - - // Construct an iterator pointing at node. Node must either be the tree root - // or nullptr (which is equivalent to end() and should not be dereferenced or - // incremented). If iterate_leaves_only is true, the iterator will not include - // interior tree nodes, only leaves. If reverse is true, the iterator will - // visit nodes in the reverse of pre-order traversal. - ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) - : node_(node), - iterate_leaves_only_(iterate_leaves_only), - reverse_(reverse) { - if (node_) { - if (reverse_) { - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - } else { - if (!node_->children.empty() && iterate_leaves_only) { - ++*this; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node, + bool iterate_leaves_only) + : nodes_(nodes), + node_(std::move(node)), + iterate_leaves_only_(iterate_leaves_only) { + while (iterate_leaves_only && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } } - ShapeTreeIterator(const ShapeTreeIterator& other) - : node_(other.node_), - stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_), - reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { - CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - if (reverse_) { - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second - 1; - stack_.pop_back(); - if (next_child_index < 0) { - if (!iterate_leaves_only_) { - // All children are visited, yield . - return *this; - } - } else { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - return *this; - } - } - } else { - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - // Otherwise we are currently at a leaf. Walk back up until a node - // contains a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - } + ++node_; + while (iterate_leaves_only_ && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } - // We've walked off the end of the tree. Set node_ to nullptr to signify - // end(). - node_ = nullptr; - current_.reset(); return *this; } ShapeTreeIterator operator++(int) { @@ -429,52 +327,62 @@ class ShapeTreeIterator : public std::iterator nodes_->begin() && + !node_->children.empty()) { + --node_; + } + return *this; + } + ShapeTreeIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - value_type& operator*() { return UpdateCurrent(); } - value_type* operator->() { return &UpdateCurrent(); } + ValueType& operator*() { return node_->data; } + ValueType* operator->() { return &node_->data; } private: - // Updates the current_ member to reflect the current state. - value_type& UpdateCurrent() { - ShapeIndex index; - for (auto& node_and_index : stack_) { - index.push_back(node_and_index.second); - } - current_ = ::xla::MakeUnique(index, node_->data); - return *current_; - } - - // The node to which this iterator is pointing. This is the source of truth in - // the iterator - the stack only exists to facilitate walking back from - // children to parents. - NodeType* node_; - // Stack of {node, child-index} pairs of the path taken from the root to get - // to node_. This allows us to backtrack and know where to go next. - std::vector> stack_; + ContainerType* nodes_; + IteratorType node_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; - // True if we should yield the reverse of the pre-order traversal. - bool reverse_; - // Placeholder for the current value. Ideally this wouldn't exist and would - // just be an rvalue, but operator -> needs to return a pointer to something. - // We cannot just use a plain old value_type as it contains a reference so - // cannot be default-constructed. - std::unique_ptr current_; }; +template +int64 ShapeTree::CountSubshapes(const Shape& shape) { + int64 current_count = 1; + if (ShapeUtil::IsTuple(shape)) { + int64 count = ShapeUtil::TupleElementCount(shape); + for (int i = 0; i < count; ++i) { + current_count += CountSubshapes(shape.tuple_shapes(i)); + } + } + return current_count; +} + template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node(init_value)); - InitChildren(shape.tuple_shapes(i), init_value, - node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index, init_value); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } } } @@ -482,63 +390,92 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node()); - InitChildren(shape.tuple_shapes(i), node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } } } template ShapeTree::ShapeTree(Shape shape) - : root_(), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template -ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { - InitChildren(*shape_, &root_); +ShapeTree::ShapeTree(const std::shared_ptr& shape) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) - : root_(init_value), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, init_value, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) - : root_(init_value), shape_(shape) { - InitChildren(*shape_, init_value, &root_); + : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const std::shared_ptr& shape, + const T& init_value) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index)->data; + return Lookup(index)->data.second; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index)->data; + return &Lookup(index)->data.second; } template internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &root_; + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); CHECK_LT(i, node->children.size()); - node = node->children[i].get(); + node = &nodes_[node->children[i]]; } return node; } @@ -552,13 +489,10 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template template -Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, node.data)); - for (int64 i = 0; i < node.children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); - index->pop_back(); +Status ShapeTree::ForEachHelper(const Fn& func, + const std::vector& nodes) { + for (const auto& node : nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); } return Status::OK(); } @@ -566,14 +500,10 @@ Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, /* static */ template template -Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, &node->data)); - for (int64 i = 0; i < node->children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, node->children[i].get(), index)); - index->pop_back(); +Status ShapeTree::ForEachMutableHelper(const Fn& func, + std::vector* nodes) { + for (auto& node : *nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } @@ -581,40 +511,36 @@ Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, root_, &index); + return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper(func, &root_, &index); + return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { - ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, - root_, &index) + nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { - ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { func(index, data); return Status::OK(); }, - &root_, &index) + &nodes_) .IgnoreError(); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4b6ab772811f4a6c6ffc1d10befc7122f883b8f9..dc5facf1581c07fbb74dfcee95025692938632bd 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -421,8 +422,8 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - t.begin()->second = 78; - EXPECT_EQ(78, t.begin()->second); + (*t.begin()).second = 78; + EXPECT_EQ(78, (*t.begin()).second); i = 0; for (auto& index_to_data : t) { if (i == 0) { @@ -434,14 +435,14 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - EXPECT_EQ(78, t.begin()->second); - EXPECT_EQ(98, std::next(t.begin())->second); + EXPECT_EQ(78, (*t.begin()).second); + EXPECT_EQ(98, (*std::next(t.begin())).second); } TEST_F(ShapeTreeTest, IterateOrder) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t) { + for (auto index_to_data : t) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{{}, @@ -479,7 +480,7 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t.leaves()) { + for (auto index_to_data : t.leaves()) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{ @@ -502,5 +503,106 @@ TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { })); } +void BM_Construct(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(shape); + } +} + +void BM_ConstructUnowned(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(&shape); + } +} + +void BM_Copy(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = shape_tree; + tensorflow::testing::DoNotOptimize(copy); + } +} + +void BM_Move(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = std::move(shape_tree); + shape_tree = std::move(copy); + } +} + +void BM_ForEach(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + shape_tree.ForEachMutableElement([](const ShapeIndex& index, int* data) { + tensorflow::testing::DoNotOptimize(index); + }); + } +} + +void BM_Iterate(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + for (auto& iter : shape_tree) { + tensorflow::testing::DoNotOptimize(iter.second); + } + } +} + +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cb8bf5a2b9e5d06f73e2116ed08630249ae8f970..82c75f85d838f94cb040e56d59d0e012af5e0db0 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -231,7 +231,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 4eb3bf3766412d5d9a8e78a4652807c5eaeef6ee..69abb51852ac09e8d357a9ba7924efc348ef2001 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc83af87396f4d51eb9e785cea93aba0b..0e1387c93938fa520562fcd63ac107a82b089a51 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc617507735fb6c4d011c39723497f69..377a618ffbd99316d409130df8a39f352664dee0 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); @@ -405,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e4f611268df824ce793c75ba1c95573455..8918350135fbb86973b228b35f5873fea8695b2f 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template -inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { +inline const Status& GetStatus(const StatusOr& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. -// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0571ff50554c5d2291198ceddc085e3c21a9f145..4883380be1f8a291bda829dff713de549ba58c65 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -87,12 +87,12 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -152,7 +152,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", @@ -188,8 +187,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -288,8 +285,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -313,7 +308,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -335,7 +329,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -378,7 +371,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -398,7 +390,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -422,8 +413,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -450,8 +439,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -472,7 +459,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -491,7 +477,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -528,7 +513,6 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -552,7 +536,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -572,8 +555,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -598,8 +579,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -626,7 +605,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -697,7 +675,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -741,7 +718,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -766,7 +742,6 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -790,7 +765,6 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -843,7 +817,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -868,7 +841,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -930,8 +902,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -960,8 +930,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1002,7 +970,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1055,8 +1022,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1078,7 +1043,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1108,8 +1072,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1240,8 +1202,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1281,7 +1241,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1304,7 +1263,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1344,7 +1302,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1362,7 +1319,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1388,8 +1344,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1411,7 +1365,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1483,8 +1436,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1532,7 +1483,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1545,6 +1495,30 @@ xla_test( ], ) +xla_test( + name = "cross_replica_sum_test", + srcs = ["cross_replica_sum_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], @@ -1574,8 +1548,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1596,7 +1568,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1620,8 +1591,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1642,7 +1611,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1661,7 +1629,6 @@ xla_test( srcs = ["execution_profile_test.cc"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1676,7 +1643,6 @@ xla_test( args = ["--xla_hlo_profile"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1782,8 +1748,6 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1811,8 +1775,6 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1850,8 +1812,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1867,7 +1827,10 @@ xla_test( xla_test( name = "local_client_execute_test", + # TODO(b/79375911): Test times out in LLVM at normal size. + size = "large", srcs = ["local_client_execute_test.cc"], + shard_count = 30, tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", @@ -1877,8 +1840,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1946,8 +1907,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -2048,7 +2007,6 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e8a5efe796a9209307ecfa343b89f66ff2a34e0f..36a706496918ac8c15780473019e2a8d098ffa22 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2225,6 +2225,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { + XlaBuilder builder(TestName()); + auto a = + builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 4e65cf11f3f1a027e1adc5a89930caba28958fea..ca337e78840e77377719636cd4cf33af2578210d 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf7191833ef85ee4a48cc96c0a3be38c71228..51b9f0d3e330e73f5d110f0a62f824179d5c7cf7 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, - error_spec_); + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_)); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_); + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index a43ca3d5ca2ba39ba9c16213e985e50bf39c0b7d..5fd33b50c94356839bbed58acd43b7d0286f4a7e 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index c09e7eaf2bb94d84d68604bff4fc97a8e8dfbc07..bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -178,8 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function arguments, const std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -230,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -248,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. @@ -263,13 +260,13 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { @@ -297,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -311,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -323,11 +320,11 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + return Status::OK(); } -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { @@ -349,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -363,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) + << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -375,8 +373,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( @@ -407,7 +405,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -419,7 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(expected, *actual, error); + EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -431,7 +429,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -444,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); } StatusOr, std::unique_ptr>> @@ -562,7 +560,36 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); + use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = Literal::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index e58979a3035dd5823be7180aeb510fa3e15f51e2..0499fec5898a42affa0e0a712dee10187355c13e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -188,11 +188,11 @@ class ClientLibraryTestBase : public ::testing::Test { const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. - tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - tensorflow::Status ComputeAndCompareLiteralWithStatus( + Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -378,12 +378,12 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function ClientLibraryTestBase::CreateR0Parameter( XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -555,7 +555,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -569,7 +569,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -583,7 +583,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -616,35 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - XlaBuilder* builder, - XlaOp* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, XlaBuilder* builder, - XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0b425b93bb144e395baef2bcf074fd6e7991630b..08671cf62445826649b5c97003f998ae98a59d97 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } @@ -91,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { auto result_literal, client_->Transfer(*results[0], &expected_result->shape())); - LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index ecce599a8a3bd588c11d6bb9ba461b5a917197db..50a006964869b3e5dce431d441f7cd81af9df910 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -50,8 +49,8 @@ class CompilationCacheTest : public ClientLibraryTestBase { /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -67,8 +66,8 @@ class CompilationCacheTest : public ClientLibraryTestBase { .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index bf4b8fb0bcf229b4e8649b3920dcba1ae0579831..ba22530f1cfee56337f862c25122d399dbf0f1e4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) { std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 4743673561a665ca8670a56bf15d85a74073e472..916ffadbc798ec0dd016f45b0bc4c36233455ee7 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,13 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -169,9 +167,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 50d6e25d868c4964ff35023b43a3734ed115bbb8..fea850dc135e33fe098aa755c6fdd93319cd2837 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 155fbacf58d81cff27939c142c8f30158cef4e00..2b3390ca98cb2922410d451c06811aa9d4ff8c0b 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectEqual(literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*empty, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b15988776513a60c9e5c85d4780912106db98e75 --- /dev/null +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +class TrivialCrossReplicaSumTest : public HloTestBase {}; + +// Currently the CPU and GPU backends only support CrossReplicaSum with one +// replica. But we can at least check this. + +XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { + const char* module_str = R"( + HloModule test + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] cross-replica-sum(p) + })"; + auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal = Literal::CreateR1({1, 2, 3}); + EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); +} + +XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { + const char* module_str = R"( + HloModule test + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] parameter(1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + })"; + auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ( + *Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); +} + +// On the GPU backend, constants get special handling. Someone might pass a +// constant to CRS to e.g. count the number of replicas -- we need to make sure +// it works. +XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { + const char* module_str = R"( + HloModule test + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] constant({10, 20}) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + })"; + auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index c76e5aabf4b8a3463b2971654d0a6cf0dd594626..bfe688e20d182d581c3e3b545ac2289413deef7c 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index d0ada2474830390e50a90c4c41aa42166d6e8ea5..12789fe66530fe03eb33316eda652336f29971ab 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b3efba4f80e45d230d3df9274d0fd40c6fb8c42..0fd846cef8095a857dd7b2c12d8afdf409e2bd66 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -61,7 +61,7 @@ using TypesF16F32F64CF64 = ::testing::Types; #endif // Check that we can safely pass an input tuple's elements to a dot operation. -TEST_F(DotOperationTest, DotOfInputTupleElem) { +XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaBuilder builder(TestName()); XlaOp param; @@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{105.0}, {105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{105.0, 105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{129.0}, {129.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{168.0}, {168.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index bfb83faf5222b8ca5ceceebf7f2f976ec803245e..49f3a10d227f2f9edfe76405ba13498fe822f8d8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, - {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } @@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // // Slice at dimension boundaries, but with out of bounds indices. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); } template @@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. According to documentation, the results are technically - // implementation specific where the update is out of bounds, and hence - // we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index b947f8208a5fa3f5a396ebc7a234afbf7ac3d900..e6f79b5ac55dddfbb213a36cadbee53bc9443d9d 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -118,9 +118,9 @@ class FusionTest : public HloTestBase { auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); } else { - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } } @@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { /*instructions_to_fuse=*/{negate3, dynamic_slice2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // TODO(b/64070202): Investigate failure. @@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(-15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { // fused instruction contains the constant(2), the parameter, and 4 adds EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 130456e61ca8a217e903d2ddecc487f29a098ce1..4854c649c15f2ab89bd3b343abd248be6e227c60 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectEqual( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 12598579c7032e954c4a4875ab8e6475b112f5ae..36e19e6507fa3b6f4a21949583f92716d2f44333 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -94,18 +94,15 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - HloModuleConfig config; - auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_max_kernel_unroll_factor(1); - config.set_debug_options(debug_options); - - return MakeUnique(name, VersionedComputationHandle(), config); + return MakeUnique(name, VersionedComputationHandle(), + GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 9539ae06801628baedaea69024b7760ebefa6e3a..eb3a2ea76a667a2afa2562f01d28f34384b84a21 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -93,6 +93,13 @@ class HloTestBase : public ::testing::Test { // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); + // Gets an HloModuleConfig with options appropriate for tests. + static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return config; + } + // Executes the given module and return the result as a Literal. StatusOr> Execute( std::unique_ptr module, diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index c28f79ae386670ca80d603a42f6629dfd30e0bc9..cde1dcd9cd10c86107f495a92be42b57bf6a085b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,978 +15,93 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include -#include -#include - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } - if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) - << "mismatch in tuple index " << i; - if (!result) { - return result; - } - } - } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } - for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - } - } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); -} - -/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( - const Shape& expected, const Shape& actual) { - ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); -} - -namespace { - -// Return a literal with all arrays of type FromNativeT converted to type -// ToNativeT in the given literal. -template -std::unique_ptr ConvertType(const Literal& literal) { - // First construct shape of the result. - Shape result_shape(literal.shape()); - ShapeUtil::ForEachMutableSubshape( - &result_shape, [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == - primitive_util::NativeToPrimitiveType()) { - subshape->set_element_type( - primitive_util::NativeToPrimitiveType()); - } - }); - auto result = MakeUnique(result_shape); - - // Then copy over the data from 'literal' converting FromNativeT values to - // ToNativeT values as necessary. - ShapeUtil::ForEachSubshape( - literal.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { - if (subshape.element_type() == - primitive_util::NativeToPrimitiveType()) { - auto src = literal.data(shape_index); - auto dest = result->data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { - dest[i] = static_cast(src[i]); - } - } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); - } - } - }); - return result; -} - -} // namespace - -/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { - return ConvertType(literal); -} - -/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { - return ConvertType(literal); -} - namespace { -string Hostname() { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); -} - -// Helper function for comparing a floating point type, FloatT, bitwise equal -// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT -// -- on miscompare, a nice error message is given in the AssertionFailure. -template -::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); - auto lhs_double = static_cast(lhs); - auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { - return ::testing::AssertionFailure() << Printf( - "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, - lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(), - rhs_double, rhs_double); - } - return ::testing::AssertionSuccess(); -} - -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). -template -::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { - if (lhs == rhs) { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + auto get_hostname = [] { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); + }; + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + literal.ToProto())); + LOG(ERROR) << "wrote to " << name << " file: " << filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. +void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + +::testing::AssertionResult StatusToAssertion(const Status& s) { + if (s.ok()) { return ::testing::AssertionSuccess(); } - ::testing::Message msg; - msg << "Expected equality of these values:"; - msg << "\n " << lhs; - msg << "\n " << rhs; - - return ::testing::AssertionFailure() << msg; -} - -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. -template <> -::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(Eigen::half lhs, - Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(complex64 lhs, - complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); - if (!res) { - return res; - } - return CompareEqual(lhs.imag(), rhs.imag()); -} - -// A recursive function which iterates through every index of expected and -// actual literal and compares their values elementwise. Returns true if all -// elements are equal. -template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { - if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = expected.Get(multi_index); - NativeT actual_value = actual.Get(multi_index); - ::testing::AssertionResult result = - CompareEqual(expected_value, actual_value); - return result; // Defines implicit coersion to bool. - } - - bool all_match = true; - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index[dimension] = i; - all_match = all_match && ExpectLiteralsEqual( - expected, actual, multi_index, dimension + 1); - } - return all_match; + return ::testing::AssertionFailure() << s.error_message(); } } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, - const string& message) { - EXPECT_TRUE(Equal(expected, actual)) - << "expected:\n" - << expected.ToString() << "\n\tvs actual:\n" - << actual.ToString() - << (message.empty() ? "" : StrCat("\nmessage: ", message)); -} - -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { - EXPECT_FALSE(Equal(expected, actual)); -} - -/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); - - AssertEqualShapes(expected.shape(), actual.shape()); - std::vector multi_index(expected.shape().dimensions_size(), 0); - bool match = false; - switch (expected.shape().element_type()) { - case PRED: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U8: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case BF16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case C64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case TUPLE: { - bool tuple_match = true; - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(StrCat("Tuple index ", i, " in ", - ShapeUtil::HumanString(expected.shape()))); - - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); - tuple_match = tuple_match ? !!result : false; - } - match = tuple_match; - break; - } - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - ::testing::AssertionResult result = ::testing::AssertionSuccess(); - if (!match) { - result = ::testing::AssertionFailure() - << "expected: " << expected.ToString() - << "\nactual: " << actual.ToString(); - VLOG(1) << result.message(); - } - return result; -} - -namespace { - -// Gets the total element count. For tuples, this is not the count of tuple -// elements, but the sum of elements of each tuple element. -int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); - int64 total = 0; - for (int64 i = 0; i < tuple_elements; ++i) { - total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); - } - return total; - } else { - return ShapeUtil::ElementsIn(shape); - } -} - -// Calling ToString on a literal with over 100 million elements takes around -// 3 minutes. The utility of printing a literal with >1000 elements is -// questionable, especially when writing the Literal proto to disk is orders -// of magnitude faster. -string TruncateHugeLiteral(const Literal& literal) { - return RecursiveElementCount(literal.shape()) < 1000 - ? literal.ToString() - : "[TRUNCATED, Literal with more than 1000 values]"; +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + return StatusToAssertion(literal_comparison::EqualShapes(expected, actual)); } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. -template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + if (expected.ShortDebugString() != actual.ShortDebugString()) { + return ::testing::AssertionFailure() + << "want: " << expected.ShortDebugString() + << " got: " << actual.ShortDebugString(); } + return ::testing::AssertionSuccess(); } -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - -template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); -} - -// Converts the given floating-point value to a string. -template -string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); -} - -template <> -string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); -} - -// Returns the absolute value of the given floating point value. This function -// is used instead of std::abs directly in order to allow type-dependent -// implementations for NearComparator. -template -float FpAbsoluteValue(NativeT value) { - return std::abs(value); -} - -template <> -float FpAbsoluteValue(bfloat16 value) { - return FpAbsoluteValue(static_cast(value)); -} - -template <> -float FpAbsoluteValue(half value) { - return FpAbsoluteValue(static_cast(value)); -} - -// Helper class for comparing floating-point literals within an error bound. -template -class NearComparator { - public: - // Compares the two array literals elementwise and returns an assertion - // result. The assertion result is successful if all actual and expected - // elements are within the given error bound. In case of error, the assertion - // result contains a detailed error message in case of failure. - static ::testing::AssertionResult Compare(const Literal& expected, - const Literal& actual, - ErrorSpec error, - bool detailed_message) { - NearComparator comparator(expected, actual, error, - detailed_message); - return comparator.Run(); - } - - private: - // Data structure encapsulating metadata about a single element mismatch. - struct Mismatch { - NativeT actual; - NativeT expected; - float rel_error; - float abs_error; - - // The linear index of the failure within the shape. This linear index is - // from the 'actual' literal. - int64 linear_index; - - bool operator<(const Mismatch& other) const { - return rel_error < other.rel_error; - } - - string ToString(const Shape& shape) const { - return Printf( - "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), - rel_error, abs_error); - } - }; - - explicit NearComparator(const Literal& expected, const Literal& actual, - ErrorSpec error, bool detailed_message) - : expected_(expected), - actual_(actual), - error_(error), - detailed_message_(detailed_message), - abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), - abs_error_buckets_(kErrorBucketBounds.size(), 0), - rel_error_buckets_(kErrorBucketBounds.size(), 0) {} - - // Runs the comparison between expected and actual literals. - ::testing::AssertionResult Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(actual_)); - - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected_.shape(), actual_.shape()); - if (!equal_shapes) { - return equal_shapes; - } - if (!ShapeUtil::IsArray(expected_.shape())) { - return ::testing::AssertionFailure() << "Expected array shape"; - } - - mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); - mismatches_.PopulateWithValue(false); - - CompareLiterals(); - - if (num_mismatches_ == 0) { - return ::testing::AssertionSuccess(); - } else if (!VLOG_IS_ON(1)) { - LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected_.shape()) - << " " << TruncateHugeLiteral(expected_); - LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual_.shape()) - << " " << TruncateHugeLiteral(actual_); - LOG(INFO) << "Dumping literals to temp files..."; - WriteLiteralToTempFile(expected_, "expected"); - WriteLiteralToTempFile(actual_, "actual"); - WriteLiteralToTempFile(mismatches_, "mismatches"); - } - return ::testing::AssertionFailure() << ErrorMessage(); - } - - // Insert the given absolute value into the absolute value bucket vector. The - // bounds of the buckets are given by kAbsValueBucketBounds. - void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { - // Adjust the bucket containing the absolute values of the 'actual' - // elements. - const float abs_value = FpAbsoluteValue(value); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - if (i == abs_value_buckets_.size() - 1 || - (abs_value >= kAbsValueBucketBounds[i] && - abs_value < kAbsValueBucketBounds[i + 1])) { - // The first value of the pair is the count of elements in the bucket, - // the second is the count of mismatches in the bucket. - abs_value_buckets_[i].first++; - if (is_mismatch) { - abs_value_buckets_[i].second++; - } - return; - } - } - } - - // Insert the given error into the given error bucket vector. - void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { - CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < error_buckets.size(); ++i) { - if (error >= kErrorBucketBounds[i]) { - error_buckets[i]++; - } - } - } - - // Compares the two given elements from the expected and actual literals at - // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); - float abs_error; - float rel_error; - if (actual == expected) { - abs_error = 0; - rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); - } else { - abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); - } - const bool is_abs_mismatch = abs_error > error_.abs; - const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); - - // Update the error of the relative bucket only if the *absolute* error - // bound is exceeded and vice versa. - if (is_abs_mismatch) { - num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); - } - if (is_rel_mismatch) { - num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); - } - - UpdateAbsValueBucket(actual, is_mismatch); - - if (!is_mismatch) { - return; - } - - num_mismatches_++; - - // Keep track of the kTopRelativeErrorCount relative error mismatches. - if (top_rel_mismatches_.size() < kTopRelativeErrorCount || - rel_error > top_rel_mismatches_.begin()->rel_error) { - Mismatch mismatch = {actual, expected, rel_error, abs_error, - linear_index}; - top_rel_mismatches_.insert(mismatch); - if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { - top_rel_mismatches_.erase(top_rel_mismatches_.begin()); - } - } - - mismatches_.data()[linear_index] = true; - } - - // Compares the two literals elementwise. - void CompareLiterals() { - // Fast path optimization for the case were layouts match. - if (LayoutUtil::Equal(actual_.shape().layout(), - expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); - const int64 len = expected_data.size(); - for (int64 i = 0; i < len; ++i) { - CompareValues(expected_data[i], actual_data[i], i); - } - return; - } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); - CompareLiteralsSlow(0, &multi_index); - } - - // Slow path for CompareLiterals when 'actual' and 'expected' literals have - // different layouts. In this case, multidimensional indices are constructed - // and indexed for each element. - void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { - if (dimension == multi_index->size()) { - CompareValues(expected_.Get(*multi_index), - actual_.Get(*multi_index), - IndexUtil::MultidimensionalIndexToLinearIndex( - actual_.shape(), *multi_index)); - } else { - for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - CompareLiteralsSlow(dimension + 1, multi_index); - } - } - } - - // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { - int64 now_usec = tensorflow::Env::Default()->NowMicros(); - string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec, - name.c_str())); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; - } - - // Returns an error message string with a detailed breakdown of the - // mismatches. Called after calling Run(). - string ErrorMessage() { - string out; - int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); - - auto percent_string = [](float a, float b) { - float pct = b == 0.0 ? 0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); - }; - - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); - if (num_nan_mismatches_ > 0) { - StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); - } - Appendf(&out, "Top relative error mismatches:\n"); - for (auto it = top_rel_mismatches_.rbegin(); - it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); - } - - if (!detailed_message_) { - return out; - } - - StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); - CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - const int64 bucket_size = abs_value_buckets_[i].first; - const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); - } - - auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { - StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); - CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); - } - }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); - print_accum_buckets( - "Relative error breakdown of elements exceeding abs error bound", - num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); - print_accum_buckets( - "Absolute error breakdown of elements exceeding rel error bound", - num_rel_mismatches_, abs_error_buckets_); - return out; - } - - // 'actual' and 'expected' literals being compared. - const Literal& expected_; - const Literal& actual_; - - // The error bounds of the comparison. - ErrorSpec error_; - - // Whether to include detailed breakdown of mismatches in the error message. - bool detailed_message_; - - // Number of element element mismatches encountered so far. - int64 num_mismatches_ = 0; - - // Number of elements with a nan mismatch. - int64 num_nan_mismatches_ = 0; - - // Number of elements which exceed the absolute/relative error bound. - int64 num_abs_mismatches_ = 0; - int64 num_rel_mismatches_ = 0; - - // A Literal containing which elements did not match in the expected and - // actual literals. mismatches_ contains PREDs and is of the same sizes as - // the comparison literals. - Literal mismatches_; - - // The number of mismatches to report in the output, sorted by relative error - // magnitude. - static constexpr int64 kTopRelativeErrorCount = 5; - - // The set of mismatches with the largest relative error. The size of this set - // is bounded by kTopRelativeErrorCount. - std::multiset top_rel_mismatches_; - - // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the - // bounds of these buckets. abs_value_buckets_ contains a pair for each - // bucket: the element count and failure count. - static constexpr std::array kAbsValueBucketBounds = { - 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; - std::vector> abs_value_buckets_; - - // Buckets for relative and absolute errors. The relative error buckets only - // contains those elements which exceed the *absolute* error bound, and vice - // versa. This makes it easy to see the effect of adjusting the relative (or - // absolute) error bound on the success of the comparison. kErrorBucketBounds - // are the lower bounds of the buckets in both vectors. The error buckets are - // a cumulative distribution so an error value may appear in more than one - // bucket. For example an error value of 0.003 may appear in the buckets - // bounded by 0.01, 0.1, and 1.0. - static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, - 0.01, 0.1, 1}; - std::vector abs_error_buckets_; - std::vector rel_error_buckets_; -}; - -template -constexpr std::array NearComparator::kAbsValueBucketBounds; -template -constexpr std::array NearComparator::kErrorBucketBounds; - -// Helper function for comparing two literals for nearness. Handles tuple-shapes -// via recursion. shape_index is the ShapeIndex of expected (or actual) -// currently being compared. -::testing::AssertionResult NearHelper(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - bool detailed_message, - const ShapeIndex& shape_index) { - ::testing::AssertionResult err = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!err) { - return err; - } - - if (ShapeUtil::IsTuple(expected.shape())) { - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - ShapeIndex element_index = shape_index; - element_index.push_back(i); - ::testing::AssertionResult res = - NearHelper(expected_element, actual_element, error, detailed_message, - element_index); - if (!res) { - string err_message = - Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), res.message()); - if (err) { - err = ::testing::AssertionFailure() << err_message; - } else { - err << err_message; - } - } - } - if (!err && shape_index.empty()) { - // Emit a top-level error message containing the top-level shape in case - // of mismatch. - int64 total_elements = RecursiveElementCount(actual.shape()); - err = ::testing::AssertionFailure() - << Printf("\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), - total_elements, err.message()); - } - return err; - } - - if (ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())) { - switch (expected.shape().element_type()) { - case BF16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F32: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case C64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - default: - LOG(FATAL) << "Unsupported primitive type in near comparator: " - << PrimitiveType_Name(expected.shape().element_type()) - << ". Must be floating-point type."; - } - } - - // Non-floating point literal. - return LiteralTestUtil::Equal(expected, actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( + const LiteralSlice& expected, const LiteralSlice& actual) { + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } -} // namespace - /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, - bool detailed_message) { - return NearHelper(expected, actual, error, detailed_message, - /*shape_index=*/{}); -} - -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - const string& message) { - ::testing::AssertionResult res = - Near(expected, actual, error, /*detailed_message=*/false); - if (!res) { - res << "Expected: " << TruncateHugeLiteral(expected) << "\n"; - res << "Actual: " << TruncateHugeLiteral(actual) << "\n"; - if (!message.empty()) { - res << StrCat("\nmessage: ", message); - } - } - EXPECT_TRUE(res); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message) { + return StatusToAssertion(literal_comparison::Near( + expected, actual, error_spec, detailed_message, &OnMiscompare)); } -/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, +/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; - return Near(expected, actual, *error); + return StatusToAssertion(literal_comparison::Near( + expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); } VLOG(1) << "Expects equal"; - return Equal(expected, actual); -} - -/*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error) { - EXPECT_TRUE(NearOrEqual(expected, actual, error)); -} - -/* static */ string LiteralTestUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); -} - -/* static */ std::unique_ptr LiteralTestUtil::Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { - int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { - new_num_elements *= new_dimensions[i]; - } - CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - - auto new_literal = MakeUnique( - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); - - // Create a new shape with the given minor-to-major layout. This shape is used - // solely for converting linear address to multi-dimensional addresses when - // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); - *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - - // Copy data into new literal, element-by-element. - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); - switch (literal.shape().element_type()) { - case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - default: - LOG(FATAL) << "Unhandled primitive element type: " - << PrimitiveType_Name(literal.shape().element_type()); - } - } - - return new_literal; + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a755568c0f098e15512bd1d3720269c867bc9c49..d1b8a6cf0b2552f1b7d95a2560d502da14ddc39a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -38,282 +39,190 @@ limitations under the License. namespace xla { -// Structure describing permissible absolute and relative error bounds. -struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) - : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - - float abs; // Absolute error bound. - float rel; // Relative error bound. - - // If relaxed_nans is true then any result is valid if we are expecting NaNs. - // In effect, this allows the tested operation to produce incorrect results - // for inputs outside its mathematical domain. - bool relaxed_nans; -}; - // Utility class for making expectations/assertions related to XLA literals. class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); - static void AssertEqualShapes(const Shape& expected, const Shape& actual); + static ::testing::AssertionResult EqualShapes( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; // Asserts that the provided shapes are equal as defined in AssertEqualShapes // and that they have the same layout. - static void AssertEqualShapesAndLayouts(const Shape& expected, - const Shape& actual); - - // If the given literal's data type is bfloat16, converts it to a float - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - - // If the given literal's data type is float, converts it to a bfloat16 - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); - - // Asserts that the expected and actual literals are (bitwise) equal for all - // elements in the literal. Also, asserts that the rank, dimensions sizes, and - // primitive type are equal. - static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + static ::testing::AssertionResult EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; - // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, - const string& message = ""); - - // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static ::testing::AssertionResult Equal(const LiteralSlice& expected, + const LiteralSlice& actual) + TF_MUST_USE_RESULT; // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); + template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + const LiteralSlice& actual); + template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + const LiteralSlice& actual); // Asserts the given literal are (bitwise) equal to given array. template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + const LiteralSlice& actual); - // Asserts that the expected and actual literals are within the given error - // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. + // Decorates literal_comparison::Near() with an AssertionResult return type. // - // Tuples are matched recursively. When comparing tensors of - // non-floating-point type, checks for exact equality, ignoring the ErrorSpec. - // - // If the shape of the literals is neither a complex/floating-point tensor nor - // a tuple which contains a complex/floating-point tensor, Near() is - // equivalent to Equal(). We don't raise an error in this case, because we - // want to allow callers to call Near() even if they have no preconceptions - // about the shapes being compared. - // - // If detailed_message is true, then the error message in the assertion result - // will contain a more detailed breakdown of mismatches. + // See comment on literal_comparison::Near(). static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message = false) TF_MUST_USE_RESULT; - // Expects expected and actual to be Near with the given error. - static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error, const string& message = ""); - // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // If the error spec is given, expects the expected and the actual to be near; - // otherwise, expects them to be equal. Tuples will be compared recursively. - static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error); - - // Returns a multi-dimensional index as a string. For example: '{7, 8}' will - // be returned for a 2-dimensional index with dimension 0 index equal to 7, - // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); - - // Creates a literal with a new shape with the given new dimensions using the - // data in the given input literal. For reshaping purposes the (flat) data - // buffer of the input literal is assumed to have the given minor_to_major - // layout order. - static std::unique_ptr Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); - - // Creates a literal with the supplied shape, and uses the provided value - // generator to populate the literal's values. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation, and using the engine as entropy generator. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, typename E, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); - private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR0(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR1(expected), actual); + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR2(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR3(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); + const Array2D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); + const Array3D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); + const Array4D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR0(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR1(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); } template @@ -321,63 +230,29 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); - return std::move(literal); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - std::normal_distribution generator(mean, stddev); - return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { - std::minstd_rand0 engine; - return CreateRandomLiteral(shape, &engine, mean, stddev); + EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 9d619a77c7e8d6398b559e8f562cd7f8194e0811..bbac7285aefbb1f028fad152e4b7fe6af01e9f6d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { std::unique_ptr literal = Literal::MakeTuple({ Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - LiteralTestUtil::ExpectEqual(*literal, *literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { + auto expected = Literal::CreateR1({1, 2, 3}); + auto actual = Literal::CreateR1({4, 5, 6}); + ::testing::AssertionResult result = + LiteralTestUtil::Equal(*expected, *actual); + EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); + EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); +} + TEST(LiteralTestUtilTest, NearComparatorR1) { auto a = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 44c6811df84f49b6c1b24c11449939e2d375a9d1..96858c00d6bbe59b673a34e7d5ca261756709596 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -274,9 +274,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -391,16 +391,16 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -754,9 +754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index e859b3059eea86b362443c3269f99ccae941dfe2..88797a7d0a7d0567b3a380c5fb1ad0c0ee875587 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -35,9 +35,9 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) { +StatusOr TestAllocator::Allocate(int device_ordinal, + uint64 size, + bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -48,8 +48,7 @@ StatusOr TestAllocator::Allocate(int device_ordinal, retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 3bbb760c806412a671bc2502846e123e2582fd16..258226523d830b40ecaa761df95988dc90f5ca47 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -46,10 +46,9 @@ class TestAllocator : public StreamExecutorMemoryAllocator { platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 464cc012140d4838de88c5bf5b3b2f1372c2c19b..27fd36e06acdc589f3a84ad561164e4a33b93506 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 0a603f4954badd12adf3144320789a5edd0d9c6c..ec7ca20bdf266cf8ed220809c0c24bee473359be 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase { expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer( std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } }; @@ -211,5 +210,68 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[4] parameter(0) + multiply = f32[4] multiply(p, p) + less-than = pred[4] less-than(p, multiply) + ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) + } + + ENTRY PredFloatMOF { + p0 = f32[4] parameter(0) + fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[4] get-tuple-element(fusion), index=0 + gte1 = f32[4] get-tuple-element(fusion), index=1 + const = f32[4] constant({0, 0, 0, 0}) + ROOT select = f32[4] select(gte0, gte1, const) + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[] parameter(0) + multiply = f32[] multiply(p, p) + less-than = pred[] less-than(p, multiply) + ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) + } + + map_computation { + p0 = f32[] parameter(0) + fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + const = f32[] constant(0) + ROOT select = f32[] select(gte0, gte1, const) + } + + ENTRY MapMOF { + p1 = f32[3] parameter(0) + ROOT map = f32[3] map(p1), to_apply=map_computation + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 97dab860c06bddb2a0ffd45e48c4912c5f55d574..838f1b4e2f0f0e0871ec717bdeefcbbc653397e3 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -161,7 +160,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); - ASSERT_NE(computation_status.status(), tensorflow::Status::OK()); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, UnusedParameter) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 29a4f75001c688f2215745ab913df68bf2f62b76..1a2de6937c3e134852a730f62f7b56417cf49b28 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - LiteralTestUtil::ExpectEqual(*result1, *result2); - LiteralTestUtil::ExpectEqual(*result1, *result3); - LiteralTestUtil::ExpectNotEqual(*result1, *result4); - LiteralTestUtil::ExpectNotEqual(*result4, *result5); - LiteralTestUtil::ExpectNotEqual(*result5, *result6); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index bcc05c2d41d8439b021cdf6533b5ca87c19aec1f..d671d40456a276a44b462f390c95aa4af301263a 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 10a3da3a387641ec45baf02d15790e32371601fa..266760e8202fddc48792ac66dda334255e428808 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -356,12 +356,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 1.0f; - }; - TF_EXPECT_OK(arg_literal->Populate(generator)); - + auto arg_literal = MakeUnique(shape); + arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); Padding padding = Padding::kValid; @@ -371,13 +367,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); - auto out_generator = - [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 27.0f; - }; - TF_EXPECT_OK(expected->Populate(out_generator)); - + auto expected = MakeUnique(result_shape); + expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -1348,7 +1339,7 @@ INSTANTIATE_TEST_CASE_P( class ReduceWindowTextTest : public HloTestBase {}; TEST_F(ReduceWindowTextTest, R2General256x384) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1365,7 +1356,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1382,7 +1373,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= } TEST_F(ReduceWindowTextTest, R2General2x5) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1399,7 +1390,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1417,7 +1408,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R3Window mul { lhs = f32[] parameter(0) @@ -1435,7 +1426,7 @@ ENTRY R3Window { } TEST_F(HloTestBase, ReduceWindowIdentity) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { param0 = f32[] parameter(0) @@ -1444,7 +1435,26 @@ identity.pad_to_reduce_window { ENTRY reduce-window-identity { operand = f32[1,32,64]{2,1,0} parameter(0) constant.4466 = f32[] constant(0) - ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window + ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + +TEST_F(HloTestBase, ReduceWindowS32) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] { + %param0 = s32[] parameter(0) + ROOT %param1 = s32[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { + %parameter.0 = s32[81,8]{1,0} parameter(0) + %parameter.1 = s32[] parameter(1) + ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window } )"; diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 5ebd5268992846e80dcce2675f8e92038e190ecf..da1b588ec41cef711412367e89b2a9b1029bca71 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d7462d581b8596dc43b81b0162b3f5020cebb546..a4580cd71d46ad0a0186eddd51291f9c322b6f49 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + expected = Literal::ConvertF32ToBF16(*expected); } - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { @@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + auto expected = Literal::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 8cbfcc6f5c4272706a0f9fd809041516bf32432b..7cfca781acda15879075f4386c2096e537877aac 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 32db45f8a66266712ba4091c2aa6368f0b822bd2..f334a8c1318a59bbfdd27dd1a63ed162600089ce 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase { client_->TransferToServer(original).ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data).ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(original, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index f35bc43a4952137b4b6c94c771819e0514d4228f..308d3fc78a51e63c0e3db8c0cda18caf11f665bd 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend / divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } @@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend % divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 3d694a9c3fe894107c3b0a8fc2e5d07310cb476c..72707f224446c7585d1d90ac6681a7b38c41d5f1 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 810cc25f1b5b1199984a3229909a70f9548c7dd2..de1865138802bc72e9a4b2db7a21343b0d327108 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -107,7 +107,7 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index e2067bc1b835a946fc56801cbf227e05ef0686b4..0063e7ad415e9b6718c164f415ced6fb76cbf44a 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { @@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { @@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { @@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { @@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 59ce23d0247b58c6aebc2b5a65453157c1ca15ff..fe1e3da7eca00e128377e6e56af877868aafa836 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5c287bac6a7cab5a3c2642971a5a67070ee56c72..41189231b90e842292830a932cf381af60456d4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -496,7 +495,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -515,7 +514,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; // Disabled on the interpreter because bitcast doesn't exist on the interpreter. -TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { const char* testcase = R"( HloModule m diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 50c8766f2e3976c7077046283ab3b3e762622fc5..c3abe22797f5eaa76ced2ad8534bd68c32983e60 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -84,6 +84,11 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +int64 UnaryOpTest::inf() { + return 0x7FFFFFFFFFFFFFFFl; +} + template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); @@ -176,6 +181,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); + SignTestHelper(); SignTestHelper(); SignTestHelper(); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7944b5132f3d11cf84488acbd920cc98c084072a..3c9a01653c67203cbc962a3d3d967142f7a2102c 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -84,8 +84,8 @@ Status ParseOneProfileOutputLine( string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = "([^ ]+)"; - string match_trops = "([^ ]+)"; + string match_flops = "([^ ]*)"; + string match_trops = "([^ ]*)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 6e3061b78a554f028b2ffae2e0590d91a4fe48e2..373c0d2d8d8ab05dec11e51f265d41b91e7920bf 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,7 +30,7 @@ limitations under the License. namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493f4309c9bf75fc9d724626267dff7ce5ed..0a1235b5e04675da0f412bafab6c4ecf04367787 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 78ab2dccafc37aa4f93da0b8d5b39a779ddd5db8..415cf9c16a2613913265d0342e5ab9932de5eb19 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -84,12 +82,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -165,12 +161,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -184,12 +179,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -202,13 +196,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7cd3343230dcaff7dc17456e9e3e702..befb55453777dce30af89bcaad2ffe1647097576 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84b487c1661af5267b9123da97bb107..cfb8f37487d6499b803438a135be54524fcf17d2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1..b815bbf854b82b323da7879c230a1026cae96625 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +39,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +64,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +72,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66f7d839f587350726333b9dbe6a9f0..a5dce20456c6a2402f425ebb3d575d1bb625f839 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3a945fb3b1b54ea92e577a6bea5f771ac0e5defd..d0e7af8844203da93dac5b45cb7e13916448dd47 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -439,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -476,10 +481,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1093,8 +1100,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1111,6 +1117,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1488,11 +1497,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1680,7 +1688,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1848,7 +1856,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 4e085bc89c6dc6021a5b9cb1c5a57f0282f41ee1..131aded95ab04c4327c275ed8cd18b8fc7ac1bd6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -937,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -957,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -969,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -982,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -993,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1022,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1036,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1050,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1065,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } @@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1138,7 +1152,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index d8cedad65ea68ef86b94394a1accf2c08517c0b2..df0501386c1e4de9111fbb6b2d9e8ec372dbf41e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,7 +17,7 @@ limitations under the License. // // Replays computations and shows the results on the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by @@ -36,14 +36,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -76,13 +74,9 @@ struct Options { // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. -template -StatusOr> ReplayComputation(const ModuleT& module, +StatusOr> ReplayComputation(const HloSnapshot& module, Client* client, const Options& opts) { - static_assert(std::is_same::value || - std::is_same::value, - "Proto must be in HloSnapshot or SessionModule format"); TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); std::vector> arguments; @@ -161,40 +155,13 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { for (char* arg : args) { HloSnapshot snapshot; auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (status.ok()) { - StatusOr> result_status = - ReplayComputation(snapshot, client, opts); - if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, - result_status.status().ToString().c_str()); - exit_status = EXIT_FAILURE; - continue; - } - - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (snapshot.has_result()) { - std::unique_ptr literal = - Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); - } - } - + if (!status.ok()) { + fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg, + status.ToString().c_str()); continue; } - fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n", - arg, status.ToString().c_str()); - - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); StatusOr> result_status = - ReplayComputation(module, client, opts); + ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -204,14 +171,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { std::unique_ptr result = result_status.ConsumeValueOrDie(); if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(result->shape()).c_str(), result->ToString().c_str()); - if (module.has_result()) { + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6afa9bda8bf639d01b8185968f79a4d..4e53fafcc97ff53afc5713e7ed8ee5222fac316b 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 750d72d797b4f8680e13597ac02f6f9fa6e37bcd..b895ac045c361b2336e0081eadf16334d49d3bee 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -814,6 +814,12 @@ enum UnaryOperation { // Elementwise, computes clz(x). UNOP_CLZ = 17; + + // Elementwise, computes exp(x)-1. + UNOP_EXPM1 = 18; + + // Elementwise, computes log(x+1). + UNOP_LOG1P = 19; } message UnaryOpRequest { diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index abdbdb4cd22ff38a0fae89af10c600a178d9a3d4..0f9c80404ad33c39ae783e0bfa3cfb26e342fe3d 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -71,6 +71,7 @@ py_library( "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9f5459f41da3e5a13286f7002e4b519978bc189b..9aad772f0acd941d50d6ba238d345616195a6939 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -60,6 +60,7 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 60306ebdc6cddb04e8807bfd495fa92a56e55ecd..c10179ba8b290b6209f5567d6323df4bcf711585 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -72,7 +72,7 @@ cc_binary( "-s", "-Wl,--gc-sections", "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - LINKER_SCRIPT, + "$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a7a3fe1452d2a3e9c2a37a25ae96f541f8f939e0 --- /dev/null +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -0,0 +1,45 @@ +# How to Contribute + +We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult [GitHub +Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +After a pull request is approved, we merge it. Note our merging process differs +from GitHub in that we pull and submit the change into an internal version +control system. This system automatically pushes a git commit to the GitHub +repository (with credit to the original author) and closes the pull request. + +## Style + +See the [TensorFlow AutoGraph style guide](STYLE_GUIDE.md). + +## Unit tests + +Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking +changes to lower the maintenance cost. +It's also helpful to check that any +changes you propose do not break existing unit tests. You can run tests using the command, + +```shell +bazel test --config=opt --copt=-O3 --copt=-march=native \ + //tensorflow/contrib/autograph/... +``` + +from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 0ba99c396fc1c8ee1e12fbb4fe0293ee52ed9bc9..674859bed4ec157d5d5b33b6fc015c930e54b392 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,6 +1,6 @@ # AutoGraph -IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! +IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..5618ec3e34499ad0f0b2a0d8b0ad04c11ee9bf9c --- /dev/null +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -0,0 +1,125 @@ +# TensorFlow AutoGraph Style Guide + +This page contains style decisions that both developers and users of TensorFlow +AutoGraph should follow to increase the readability of their code, reduce the +number of errors, and promote consistency. We borrow many style principles from the TensorFlow Probability style guide. + +## TensorFlow Style + +Follow the [TensorFlow style +guide](https://www.tensorflow.org/community/style_guide) and [documentation +guide](https://www.tensorflow.org/community/documentation). Below are additional +TensorFlow conventions not noted in those guides. In the future, these noted +conventions may be moved upstream. + +1. The name is TensorFlow, not Tensorflow. +2. The name is AutoGraph, not Autograph. + +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + +## TensorFlow AutoGraph Style + +Below are TensorFlow AutoGraph-specific conventions. In the event of conflict, +it supercedes all previous conventions. + +1. __Importing submodule aliases.__ Use the Pythonic style +`from tensorflow.contrib.autograph.converters import ifexp` and `from tensorflow.contrib import autograph as ag`. + +2. __Examples in Docstrings.__ Write a `#### Examples` subsection below `Args`, + `Returns`, `Raises`, etc. to illustrate examples. If the docstring's last + line is a fence bracket (\`\`\`) closing a code snippet, add an empty line + before closing the docstring with \"\"\". This properly displays the code + snippet. + + Justification: Users regularly need to remind themselves of args and + semantics. But rarely look at examples more than the first time. But since + examples are usually long (which is great!) it means they have to do a lot + of annoying scrolling ...unless Examples follow Args/Returns/Raises. + +3. __Citations in Docstrings.__ Write a `#### References` subsection at the + bottom of any docstring with citations. Use ICLR’s bibliography style to + write references; for example, order entries by the first author's last + name. Add a link to the paper if the publication is open source (ideally, + arXiv). + + Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1]. + Write in-text citations when the citation is a noun, e.g., [Tran and Blei + (2018)][1]. Write citations with more than two authors using et al., e.g., + [(Tran et al., 2018)][1]. Separate multiple citations with semicolon, e.g., + ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]). + + Examples: + + ```none + #### References + + # technical report + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + + # journal + [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation + Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992. + + # arXiv preprint + # use "et al." for papers with too many authors to maintain + [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 + + # conference + [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse. + Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches. In _International Conference on Learning + Representations_, 2018. + https://arxiv.org/abs/1803.04386 + ``` + +4. When doing float math over literals eg use `1.` instead of `1` or `1.0`. + + * Using `1.` is another line of defense against an automatic casting + mistake. (Using `1.0` is also such a defense but is not minimal.) + +5. Prefer using named args for functions' 2nd args onward. + + * Definitely use named args for 2nd args onward in docstrings. + +9. Avoid LaTeX in docstrings. + + * It is not rendered in many (if not most) editors and can be hard to read + for both LaTeX experts and non-experts. + +10. Write docstring and comment math using ASCII friendly notation; python using + operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, + `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: + x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. + + * The more we stick to python style, the more someone can + copy/paste/execute. + * Python style is usually easier to read as ASCII. + +11. All public functions require docstrings with: one line description, Args, + Returns, Raises (if raises exceptions). + + * Returns docstrings should be in the same format as Args, eg, of the form + "name: Description." Part of the rationale is that we are suggesting a + reasonable variable name for the returned object(s). + +12. Regard `*args` and/or `**kwargs` as features of last resort. + + * Keyword arguments make the intention of a function call more clear. + * [Possible exceptions for + `kwargs`](https://stackoverflow.com/questions/1415812/why-use-kwargs-in-python-what-are-some-real-world-advantages-over-using-named). + +18. The `__init__.py` file for modules should use TensorFlow's + `remove_undocumented` feature, which seals the module's methods. + +21. Use `"{}".format()` rather than `"" %` for string formatting. + + Justification: [PEP 3101](https://www.python.org/dev/peps/pep-3101/) and + [Python official + tutorials](https://docs.python.org/3.2/tutorial/inputoutput.html#old-string-formatting): + "...this old style of formatting will eventually be removed from the + language, str.format() should generally be used." diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 1be1c96dd31bf05b746fae6a2b02774e20ca0c4f..35877224b87c1abda1a270be4869e9dcfd0cf97c 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -52,8 +54,13 @@ class BreakStatementTransformer(transformer.Base): def _guard_if_present(self, block, var_name): """Prevents the block from executing if var_name is set.""" + + # If we don't have statements that immediately depend on the break + # we still need to make sure that the break variable remains + # used, in case the break becomes useful in later stages of transformation. + # Not having this broke the break_in_inner_loop test. if not block: - return block + block = [gast.Pass()] template = """ if not var_name: block diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index 554f0471d44d54194c45c3855b1483796ae65a6a..b6ecdcb7809b1ad7e7461324cb6a110ef4180609 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -292,15 +292,25 @@ class CallTreeTransformer(transformer.Base): raise NotImplementedError( 'py_func with return values (unknown function)') else: + if anno.hasanno(node.func, anno.Basic.QN): + # Special-case a few builtins that otherwise go undetected. This + # normally doesn't pose a problem, but the dict built-in doesn't + # work with inspect.getargspec which is required for dynamic functions. + # Note: expecting this is resilient to aliasing (e.g. + # dict = an_evil_dict), because in those cases the regular mechanisms + # process a simple user function. + qn = anno.getanno(node.func, anno.Basic.QN) + # Add items to this list as needed. + if str(qn) in ('dict',): + return node + if ast_util.matches(node, 'super(_)'): # super() calls are preserved. The class conversion mechanism will # ensure that they return the correct value. - pass - elif self.context.recursive: + return node + + if self.context.recursive: node = self._insert_dynamic_conversion(node) - else: - # Unresolved functions are allowed in non-recursive mode. - pass return node diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 935a2786db0289c67860be2da97e3f554f12500c..d7ddbe8a04f64848d6ec21155d8d85f60e19d276 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handles control flow statements: while, if.""" +"""Handles control flow statements: while, for, if.""" from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -47,9 +48,6 @@ class SymbolNamer(object): class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, context): - super(ControlFlowTransformer, self).__init__(context) - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -98,30 +96,63 @@ class ControlFlowTransformer(transformer.Base): body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) - - if body_scope.created - orelse_scope.created: - raise ValueError( - 'The if branch creates new symbols that the else branch does not.') - if orelse_scope.created - body_scope.created: - raise ValueError( - 'The else branch creates new symbols that the if branch does not.') - - modified = tuple(body_scope.modified | orelse_scope.modified) - all_referenced = body_scope.referenced | orelse_scope.referenced + body_defs = body_scope.created | body_scope.modified + orelse_defs = orelse_scope.created | orelse_scope.modified + live = anno.getanno(node, 'live_out') + + # We'll need to check if we're closing over variables that are defined + # elsewhere in the function + # NOTE: we can only detect syntactic closure in the scope + # of the code passed in. If the AutoGraph'd function itself closes + # over other variables, this analysis won't take that into account. + defined = anno.getanno(node, 'defined_in') + + # We only need to return variables that are + # - modified by one or both branches + # - live (or has a live parent) at the end of the conditional + modified = [] + for def_ in body_defs | orelse_defs: + def_with_parents = set((def_,)) | def_.support_set + if live & def_with_parents: + modified.append(def_) + + # We need to check if live created variables are balanced + # in both branches + created = live & (body_scope.created | orelse_scope.created) + + # The if statement is illegal if there are variables that are created, + # that are also live, but both branches don't create them. + if created: + if created != (body_scope.created & live): + raise ValueError( + 'The main branch does not create all live symbols that the else ' + 'branch does.') + if created != (orelse_scope.created & live): + raise ValueError( + 'The else branch does not create all live symbols that the main ' + 'branch does.') # Alias the closure variables inside the conditional functions # to avoid errors caused by the local variables created in the branch # functions. - need_alias = ( - (body_scope.modified | orelse_scope.modified) - - (body_scope.created | orelse_scope.created)) - aliased_orig_names = tuple(need_alias) - aliased_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), all_referenced) - for s in aliased_orig_names) - alias_map = dict(zip(aliased_orig_names, aliased_new_names)) - node_body = ast_util.rename_symbols(node.body, alias_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_map) + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. + aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) + aliased_orelse_orig_names = tuple(orelse_scope.modified - + orelse_scope.created) + aliased_body_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) if not modified: # When the cond would return no value, we leave the cond called without @@ -134,26 +165,47 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', all_referenced) - orelse_name = self.context.namer.new_symbol('if_false', all_referenced) + body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.context.namer.new_symbol('if_false', + orelse_scope.referenced) if modified: - body_returns = tuple( - alias_map[s] if s in aliased_orig_names else s for s in modified) + + def build_returns(aliased_names, alias_map, scope): + """Builds list of return variables for a branch of a conditional.""" + returns = [] + for s in modified: + if s in aliased_names: + returns.append(alias_map[s]) + else: + if s not in scope.created | defined: + raise ValueError( + 'Attempting to return variable "%s" from the true branch of ' + 'a conditional, but it was not closed over, or created in ' + 'this branch.' % str(s)) + else: + returns.append(s) + return tuple(returns) + + body_returns = build_returns(aliased_body_orig_names, alias_body_map, + body_scope) + orelse_returns = build_returns(aliased_orelse_orig_names, + alias_orelse_map, orelse_scope) + else: - body_returns = templates.replace('tf.ones(())')[0].value + body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_body_orig_names), + aliased_new_names=tuple(aliased_body_new_names), body=node_body, returns=body_returns) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_orelse_orig_names), + aliased_new_names=tuple(aliased_orelse_new_names), body=node_orelse, - returns=body_returns) + returns=orelse_returns) cond_expr = self._create_cond_expr(results, node.test, body_name, orelse_name) @@ -284,6 +336,7 @@ class ControlFlowTransformer(transformer.Base): def transform(node, context): - t = ControlFlowTransformer(context) - node = t.visit(node) + cfg.run_analyses(node, cfg.Liveness(context)) + cfg.run_analyses(node, cfg.Defined(context)) + node = ControlFlowTransformer(context).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index c5610b16b4e5de374f404307d3583660707d5e0b..1a863590f97add9bfa587d1142a09ae26a9fdb44 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -95,6 +96,91 @@ class ControlFlowTest(converter_test_base.TestCase): with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + def test_imbalanced_aliasing(self): + + def test_fn(n): + if n > 0: + n = 3 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_ignore_unread_variable(self): + + def test_fn(n): + b = 3 # pylint: disable=unused-variable + if n > 0: + b = 4 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_handle_temp_variable(self): + + def test_fn_using_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z, w + + node = self.parse_and_analyze(test_fn_using_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + self.assertEqual(3, w) + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + self.assertEqual(2, w) + + def test_fn_ignoring_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z + + node = self.parse_and_analyze(test_fn_ignoring_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + def test_simple_for(self): def test_fn(l): diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 5edd8e74a8899a25fb51e2a4e133f3cb7933fa26..bc61498b5422f5e130bbfeef935d0a796b4f5922 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import training from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 83f3bafc4217649db6499566d548c1657428ad0b..8064a967cd389e88d3febbeb21cac87b0fef9e18 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], @@ -43,6 +44,19 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "live_values_test", srcs = ["live_values_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..ad97fdfa8e78d1fd4c38724612d83519c6609cce --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,445 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. + """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [test] + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + outgoing |= anno.getanno(node.test, self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_For(self, node): + self.generic_visit(node) + incoming = set(anno.getanno(node.body[0], self.in_label)) + incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, frozenset(incoming)) + anno.setanno(node, self.out_label, outgoing) + + def visit_While(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_With(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + for item in node.items: + incoming |= anno.getanno(item, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + +# TODO(alexbw): Abstract the CFG walking machinery into a superclass +# which is parameterized on which fields it selects when walking. +# TODO(alexbw): Abstract the application of dataflow analysis +class Forward(object): + """Forward analysis on CFG. + + Args: + label: A name for this analysis e.g. 'active' for activity analysis. The AST + nodes in the CFG will be given annotations 'name_in', 'name_out', + 'name_gen' and 'name_kill' which contain the incoming values, outgoing + values, values generated by the statement, and values deleted by the + statement respectively. + transfer_fn: Either the AND or OR operator. If the AND operator is used it + turns into forward must analysis (i.e. a value will only be carried + forward if it appears on all incoming paths). The OR operator means that + forward may analysis is done (i.e. the union of incoming values will be + taken). + """ + + def __init__(self, label, context, transfer_fn=operator.or_): + self.transfer_fn = transfer_fn + self.context = context + self.out_label = label + '_out' + self.in_label = label + '_in' + self.gen_label = label + '_gen' + self.kill_label = label + '_kill' + + # TODO(alexbw): see if we can simplify by visiting breadth-first + def visit(self, node): + """Depth-first walking the CFG, applying dataflow information propagtion.""" + # node.value is None only for the exit CfgNode. + if not node.value: + return + + if anno.hasanno(node.value, self.out_label): + before = hash(anno.getanno(node.value, self.out_label)) + else: + before = None + preds = [ + anno.getanno(pred.value, self.out_label) + for pred in node.prev + if anno.hasanno(pred.value, self.out_label) + ] + if preds: + incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) + else: + incoming = frozenset() + anno.setanno(node.value, self.in_label, incoming) + gen, kill = self.get_gen_kill(node, incoming) + anno.setanno(node.value, self.gen_label, gen) + anno.setanno(node.value, self.kill_label, kill) + anno.setanno(node.value, self.out_label, (incoming - kill) | gen) + + if hash(anno.getanno(node.value, self.out_label)) != before: + for succ in node.next: + self.visit(succ) + + def get_gen_kill(self, cfg_node, incoming): + """Calculate Gen and Kill properties of a CFG node in dataflow analysis. + + A function which takes the CFG node as well as a set of incoming + values. It must return a set of newly generated values by the statement as + well as a set of deleted (killed) values. + + Args: + cfg_node: A CfgNode instance. + incoming: + """ + raise NotImplementedError() + + +class Backward(Forward): + """Backward analysis on CFG.""" + + def visit(self, cfg_node): + # cfg_node.value is None for the exit node, which will be visited only once + if not cfg_node.value: + for pred in cfg_node.prev: + self.visit(pred) + return + + if anno.hasanno(cfg_node.value, self.in_label): + before = hash(anno.getanno(cfg_node.value, self.in_label)) + else: + before = None + succs = [ + anno.getanno(succ.value, self.in_label) + for succ in cfg_node.next + if anno.hasanno(succ.value, self.in_label) + ] + if succs: + incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) + else: + incoming = frozenset() + anno.setanno(cfg_node.value, self.out_label, incoming) + gen, kill = self.get_gen_kill(cfg_node, incoming) + anno.setanno(cfg_node.value, self.gen_label, gen) + anno.setanno(cfg_node.value, self.kill_label, kill) + anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) + if hash(anno.getanno(cfg_node.value, self.in_label)) != before: + for pred in cfg_node.prev: + self.visit(pred) + + +def run_analyses(node, analyses): + """Perform dataflow analysis on all functions within an AST. + + Args: + node: An AST node on which to run dataflow analysis. + analyses: Either an instance of the Forward or Backward dataflow analysis + class, or a list or tuple of them. + + Returns: + node: The node, but now with annotations on the AST nodes containing the + results of the dataflow analyses. + """ + if not isinstance(analyses, (tuple, list)): + analyses = (analyses,) + for analysis in analyses: + if not isinstance(analysis, (Forward, Backward)): + raise TypeError('not a valid forward analysis object') + + for child_node in gast.walk(node): + if isinstance(child_node, gast.FunctionDef): + cfg_obj = CfgBuilder().build_cfg(child_node) + for analysis in analyses: + if isinstance(analysis, Backward): + analysis.visit(cfg_obj.exit) + elif isinstance(analysis, Forward): + analysis.visit(cfg_obj.entry) + for analysis in analyses: + PropagateAnalysis(analysis).visit(node) + return node + + +class Liveness(Backward): + """Perform a liveness analysis. + + Each statement is annotated with a set of variables that may be used + later in the program. + """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + # A variable's parents are live if it is live + # e.g. x is live if x.y is live. This means gen needs to return + # all parents of a variable (if it's an Attribute or Subscript). + # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) + gen = activity.get_read(node.value, self.context) + gen = functools.reduce(lambda left, right: left | right.support_set, gen, + gen) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fc07fa3447b23c0595a5893329de8a2d7055ca15 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. + # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def test_live_straightline(self): + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + + def f6(x): + return x # should this cause x.* to be live? + + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + + def f7(x, n): + for i in range(n): + x += i + return x + + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_context_manager(self): + + def f8(x, f): + with f: + x += 1 + + body = self._get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + def test_loop_else(self): + + # Disabling useless-else-on-loop error, because 'break' and 'continue' + # canonicalization are a separate analysis pass, and here we test + # the CFG analysis in isolation. + def for_orelse(x): + y = 0 + for i in range(len(x)): + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + def while_orelse(x, i): + y = 0 + while x < 10: + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + for f in (for_orelse, while_orelse): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + return_node = body[-1] + reaching_defs = anno.getanno(return_node, 'definitions_in') + + # Y could be defined by Assign(Num(0)) or Assign(Num(1)) + # X could be defined as an argument or an AugAssign. + y_defs = [node for var, node in reaching_defs if str(var) == 'y'] + x_defs = [node for var, node in reaching_defs if str(var) == 'x'] + + self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) + self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) + self.assertEqual(len(y_defs), 2) + self.assertEqual( + set((gast.arguments, gast.AugAssign)), + set(type(def_) for def_ in x_defs)) + self.assertEqual(len(x_defs), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 9994c84ebdb930eea0818188225488eb5eca84eb..758754feac31f1d2cf10e69d7a9a6d288931c900 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -45,6 +45,7 @@ from tensorflow.python.training import training_util _DNN_LEARNING_RATE = 0.001 + def _get_optimizer(optimizer): if callable(optimizer): return optimizer() @@ -73,6 +74,7 @@ def _dnn_tree_combined_model_fn(features, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -108,6 +110,8 @@ def _dnn_tree_combined_model_fn(features, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -132,8 +136,7 @@ def _dnn_tree_combined_model_fn(features, dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( - max_partitions=config.num_ps_replicas, - min_slice_size=64 << 20)) + max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) with variable_scope.variable_scope( dnn_parent_scope, @@ -171,8 +174,7 @@ def _dnn_tree_combined_model_fn(features, _add_hidden_layer_summary(net, hidden_layer_scope.name) previous_layer = net with variable_scope.variable_scope( - "logits", - values=(previous_layer,)) as logits_scope: + "logits", values=(previous_layer,)) as logits_scope: dnn_logits = layers.fully_connected( previous_layer, head.logits_dimension, @@ -190,8 +192,7 @@ def _dnn_tree_combined_model_fn(features, optimizer=_get_optimizer(dnn_optimizer), name=dnn_parent_scope, variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope), + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), # Empty summaries to prevent optimizers from logging training_loss. summaries=[]) @@ -230,7 +231,10 @@ def _dnn_tree_combined_model_fn(features, update_op = state_ops.assign_add(global_step, 1).op return update_op - tree_train_logits = dnn_logits + tree_logits + if predict_with_tree_only: + tree_train_logits = tree_logits + else: + tree_train_logits = dnn_logits + tree_logits def _no_train_op_fn(loss): """Returns a no-op.""" @@ -288,10 +292,10 @@ def _dnn_tree_combined_model_fn(features, finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp( - dnn_train_op, dnn_steps_to_train, tree_train_op), - trainer_hooks.StopAfterNTrees( - num_trees, attempted_trees, finalized_trees)]) + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) + ]) return model_fn_ops @@ -318,6 +322,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -360,6 +365,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -377,16 +384,32 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedRegressor(estimator.Estimator): @@ -410,6 +433,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -452,6 +476,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -474,16 +500,32 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedEstimator(estimator.Estimator): @@ -508,6 +550,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -545,6 +588,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -553,15 +598,32 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. """ + def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318..04e32267cc4a00b3169c3abbcbf549805a0fb462 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -422,6 +422,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } present_gradient_stats *= normalizer_ratio; + GradientStats not_present = + root_gradient_stats - present_gradient_stats; + // If there was (almost) no sparsity, fix the default direction to LEFT. + bool fixed_default_direction = not_present.IsAlmostZero(); GradientStats left_gradient_stats; for (int64 element_idx = start_index; element_idx < end_index; @@ -441,6 +445,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // backward pass gradients. GradientStats right_gradient_stats = present_gradient_stats - left_gradient_stats; + { NodeStats left_stats_default_left = ComputeNodeStats(root_gradient_stats - right_gradient_stats); @@ -457,7 +462,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { best_dimension_idx = dimension_id; } } - { + // Consider calculating the default direction only when there were + // enough missing examples. + if (!fixed_default_direction) { NodeStats left_stats_default_right = ComputeNodeStats(left_gradient_stats); NodeStats right_stats_default_right = diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 9d6cc9245aa463d0c8cfc7ad209736357b6c0323..f06b73c00d0bebb2717a79b7894e2addf914daba 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -501,11 +501,18 @@ def sparse_make_stats_update( example_partition_ids) # Compute aggregate stats for each partition. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. bias_feature_ids = array_ops.fill( diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1..5cd37ec67ec3bdefb6ea19049a7a12249162d45a 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops @@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(0.6, split_node.split.threshold) + def testMakeSparseSplitDefaultDirectionIsStable(self): + """Tests default direction is stable when no sparsity.""" + random.seed(1123) + for _ in range(50): + with self.test_session() as sess: + grad = random.random() + hessian = random.random() + # The data looks like the following (divide by the num of steps 2). + # Gradients | Partition | bucket ID | + # (grad, hessian) | 0 | -1 | + # And then 100 buckets of + # (grad/100, hessian/100), so there is no sparsity. + n_buckets = 100 + + # 1 for the overall sum, and 100 buckets. + partition_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + + bucket_ids = [-1] + [n for n in range(100)] + bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64) + dimension_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = [grad] + [grad / n_buckets] * n_buckets + gradients = array_ops.constant(gradients) + hessians = [hessian] + [hessian / n_buckets] * n_buckets + hessians = array_ops.constant(hessians) + + boundaries = [x * 1 for x in range(n_buckets + 1)] + bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32) + + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0], partitions) + self.assertEqual(1, len(splits)) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + self.assertTrue( + split_info.split_node.HasField( + 'sparse_float_binary_split_default_left')) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 1b184d296b329cee481db67992e77d1e33e18035..50cc00afdcc77fedc9bf8c94a9a6fcf2a28ebde9 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -187,7 +187,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: - A list of quantiles or approximate boundaries. + The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index d2c30f121539f8eae5d5f921bd7a1507a81f6e29..af8df72618b7255e182e98e6e4b96a0333b3dce6 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -14,22 +14,29 @@ # ============================================================================== """Tools for working with object-based checkpoints. - -For creating and managing dependencies: -@@CheckpointableObjectGraph +Visualization and inspection: @@dot_graph_from_checkpoint @@object_metadata + +Creating and managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph +@@NoDependency @@split_dependency +@@UniqueNameTracker """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph -from tensorflow.python.training.checkpointable_utils import object_metadata +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index a5681ffa61d07ef29d0a0862db9736a210c8e26e..53f4e97f9932104933b3ecf80142e5af82cd487a 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -8,11 +8,34 @@ py_library( name = "checkpoint", srcs_version = "PY2AND3", deps = [ + ":containers", ":split_dependency", ":visualize", ], ) +py_library( + name = "containers", + srcs = ["containers.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python/training/checkpointable:base"], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", + "@six_archive//:six", + ], +) + py_library( name = "split_dependency", srcs = ["split_dependency.py"], diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 0000000000000000000000000000000000000000..9807abae1f5106bb84f858c3725f096aaa4eaca9 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,77 @@ +"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base as checkpointable_lib + + +class UniqueNameTracker(checkpointable_lib.CheckpointableBase): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. + self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. + """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + return self._track_checkpointable(checkpointable, name=candidate) diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..851a80058852bd917aec075b4bf63264318603a7 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) + slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) + status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = checkpointable_utils.object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3aec8c96e90440d6da00d95cffc34bd53ec7164f..7e77453f3d848c2e321ed2ba66917a742d95459a 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -20,8 +20,8 @@ from __future__ import print_function import functools from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable as checkpointable from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import base as checkpointable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index f1d9d19b047ee69281cf8bdba38a28dc87947e38..69dc0b9be2d5548852c37552a64a0d31c9557b43 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,8 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def _split_variable_closure(variable): diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index 9a3b23bb2c30ee601f5f94da31ad182399a04e4f..bac071c4cff383f60b707b6e42c13faf5e0ac948 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def dot_graph_from_checkpoint(save_path): diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 1d9ab789235cb964521315b4864563f89745ae75..583e3bc442893d825c337d73fb999d1e586738a1 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -24,11 +24,11 @@ from tensorflow.contrib.checkpoint.python import visualize from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils try: import pydot # pylint: disable=g-import-not-at-top diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1403483d287041b02dfbf538f7e7ddee11662f47..880fca4ea65608472838baee234e468bef37afb3 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,8 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -70,6 +72,16 @@ class TPUClusterResolver(ClusterResolver): def _gkeMaster(): return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + @staticmethod + def _envVarFallback(): + if _DEFAULT_ENV_VARIABLE in os.environ: + return os.environ[_DEFAULT_ENV_VARIABLE] + return None + + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) + def __init__(self, tpu=None, zone=None, @@ -78,7 +90,8 @@ class TPUClusterResolver(ClusterResolver): coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -108,6 +121,11 @@ class TPUClusterResolver(ClusterResolver): service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -123,8 +141,11 @@ class TPUClusterResolver(ClusterResolver): in_gke = self._inGke() # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None and in_gke: - tpu = self._gkeMaster() + if tpu is None: + if in_gke: + tpu = self._gkeMaster() + else: + tpu = self._envVarFallback() self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name @@ -154,9 +175,16 @@ class TPUClusterResolver(ClusterResolver): '--upgrade google-api-python-client` to install with ' 'pip.') - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5b3f9be5a11237f9dceebefa1db294efaf7e482d..5fac55fd027fa2d100621e08a09e05cdb3a1b941 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -367,6 +367,10 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes(TPUClusterResolver._gkeMaster())) del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 6468bed4979253be5c20666d26bf24fa479d64a0..fece56c4127de4deebc1404f0eff9747f99ba89f 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,52 +32,13 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist -tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/estimator -tensorflow/python/keras/initializers +tensorflow/python/keras/engine tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence -tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions @@ -100,6 +61,7 @@ tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/training +tensorflow/python/training/checkpointable tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -333,6 +295,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mixed_precision +tensorflow/contrib/mixed_precision/python tensorflow/contrib/mpi_collectives/python tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index d63c41db844af243f0c6600b1565635ac9b91cac..cf1ee2ad76f2cc9f58dbe90182a3e17f1edc7ed3 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -11,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca075c8de96786a580c7ddb89541df5bc..a06bdf78fb011b288d5d7af6488ec6802ff34c35 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -22,8 +22,6 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" @@ -38,13 +36,15 @@ add_dependencies( tf_core_lib tf_protos_cc) -add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" -) -add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) +if(tensorflow_BUILD_PYTHON_BINDINGS) + add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" + ) + add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc) +endif() diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f73da0b8ab18af1eca4c2bd577604595f8b8ec6d..6c90cf398c69c8c1b22ea75e0c407f258e2535f9 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -155,7 +155,7 @@ if (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") endif() else (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) add_custom_target(tf_extension_ops) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 1505d3e2083b5a3446a7f85d59c73816e65e1a2a..2d76bf530a2100b2afa80a16a5d64b6ec51ffc68 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index c4bdb69d828b269e6246777e74c3756ba1c4b96f..894b1ead7688bd951c5c78f613fdb7aae226fe65 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD # tf_python_op_gen_main library ######################################################## set(tf_python_op_gen_main_srcs - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) @@ -464,12 +462,12 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" @@ -715,7 +713,7 @@ if(WIN32) endif() else() add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) endif() @@ -791,7 +789,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index cffe069aa352f8a6f2c436bc70b62f54e2336ac6..4f957f1e0b46fde5daacbc59657af994e13c42d5 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -44,7 +44,8 @@ UNDNAME = "undname.exe" DUMPBIN = "dumpbin.exe" # Exclude if matched -EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::|Internal|" + r"python_op_gen_internal|grappler") # Include if matched before exclude INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" @@ -56,6 +57,10 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" r"tensorflow::ops::internal::Enter|" r"tensorflow::strings::internal::AppendPieces|" r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::errors::Internal|" + r"tensorflow::Tensor::CopyFromInternal|" + r"tensorflow::kernel_factory::" + r"OpKernelRegistrar::InitInternal|" r"tensorflow::io::internal::JoinPathImpl") # Include if matched after exclude @@ -64,7 +69,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"tensorflow::|" r"functor::|" r"\?nsync_|" - r"perftools::gputools") + r"stream_executor::") # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc index ae4d9d2836a0f89a9765004a85bc3c292b0e484f..81b36ca902b82220d9c5282a1ec72324a6d95922 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py index f039cb0f5265b920200f63c5bd5ebeb4e23826be..0fbe3081af0b4de7f116918b3f49efe91a2d83bd 100644 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import engine +from tensorflow.python.keras import engine from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index b2f678fb29cedd3ec32f0460354cc4ac18fb63d3..a56a01b16356e12b83344474c7fbe427530f0c74 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -170,7 +169,6 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) -@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 33ddfb8dee1c446f22c7d0071f9a0e2bbac6bdad..8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -54,11 +54,11 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 73a961992e19fabec5d0f75be1b52dbba20eb7af..ed0a26bbd87eeb5bd005de8f9d054d315e378529 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops @@ -33,8 +33,8 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 077cbba9d2ae41a83f6c358a63ae27aec5741e2c..a25aa85251083c24ca6685c4ffef267955f66f63 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,6 +23,8 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter +@@CheckpointInputPipelineHook +@@CsvDataset @@SqlDataset @@assert_element_shape @@ -72,8 +74,10 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index c56910c7833d4c54fa8db27cd061b404013f3f54..7b69e10441eba3e38c979d5715c16699ac2710ed 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -29,6 +29,16 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "ignore_errors_dataset_op", srcs = ["ignore_errors_dataset_op.cc"], @@ -63,6 +73,7 @@ cc_library( cc_library( name = "dataset_kernels", deps = [ + ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..76e54a284e07ec1bab9b0f364a44997a39bce78a --- /dev/null +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -0,0 +1,508 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "na_value", &na_value)); + + std::vector record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + std::vector select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], + errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + bool select_all_cols = select_cols.empty(); + + *output = new Dataset( + ctx, std::move(filenames), header, buffer_size, output_types_, + output_shapes_, std::move(record_defaults), std::move(select_cols), + select_all_cols, use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector filenames, bool header, + int64 buffer_size, const DataTypeVector& output_types, + const std::vector& output_shapes, + std::vector record_defaults, std::vector select_cols, + bool select_all_cols, bool use_quote_delim, char delim, + string na_value) + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + header_(header), + buffer_size_(buffer_size), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + select_all_cols_(select_all_cols), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + // TODO(rachelim): Implement this + std::vector input_tensors; + TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); + return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a file, so try to read the next record + if (buffered_input_stream_) { + Status s = ReadRecord(ctx, out_tensors); + if (s.ok() || !errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement save + return errors::Unimplemented("CSVDataset: SaveInternal"); + } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement restore + return errors::Unimplemented("CSVDataset: RestoreInternal"); + } + + private: + // Reads a record by parsing the input buffer, and converting extracted + // fields to output tensors as we go. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // Extracts fields from line(s) from the buffered input stream. + out_tensors->reserve(dataset()->record_defaults_.size()); + + string input; + TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); + + size_t current_idx = 0; + size_t num_fields_parsed = 0; + size_t selector_idx = 0; // Keep track of index into select_cols + + while (current_idx < input.size()) { + // In each iteration, parse one field + if (input[current_idx] == '\n' || input[current_idx] == '\r') { + // This should never happen, because buffered input reader splits + // input on newlines. + return errors::InvalidArgument("Parsing error."); + } + + bool quoted = false; + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); + + if (dataset()->use_quote_delim_ && input[current_idx] == '"') { + quoted = true; + current_idx++; + } + + // Parse the body of the field + string field; + if (!quoted) { + while (current_idx < input.size() && + input[current_idx] != dataset()->delim_) { + if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || + input[current_idx] == '\n' || input[current_idx] == '\r') { + return errors::InvalidArgument( + "Unquoted fields cannot have quotes/CRLFs inside"); + } + if (include) field += input[current_idx]; + current_idx++; + } // Exit condition: end of input, or current index at delim + + // Go to next field or the end + current_idx++; + } else { + // Quoted field needs to be ended with '"' and delim or end + while (true) { + if (current_idx >= input.size() - 1 || input.empty()) { + if (current_idx == input.size() - 1 && + input[current_idx] == '"') { + // We're at the end of the input, and the quote terminates the + // record. Go to end. + current_idx++; + break; + } + // If there's no terminating quote, it means our buffered record + // line reader split a record up. This can happen if there is a + // newline encased in quotes. The next line is also part of the + // record, so we read it and reset the index. + if (include && current_idx == input.size() - 1) { + // TODO(rachelim): Instead of building up a string, keep track + // of terminal indices (or starting char* and length) + // Also look into using /lib/strings/Scanner + field += input[current_idx]; + } + if (include) { + field += '\n'; + } + current_idx = 0; + Status s = buffered_input_stream_->ReadLine(&input); + if (!s.ok()) { + return errors::InvalidArgument( + "Quoted field has to end with quote followed by delim, " + "CRLF, or EOF"); + } + } else if (input[current_idx] == '"' && + input[current_idx + 1] == dataset()->delim_) { + // End of field, go to next field or end + current_idx += 2; + break; + } else if (input[current_idx] == '"') { + // Current char is a quote. Since we're not at end of field, + // the next character must also be a quote. + if (input[current_idx + 1] != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another " + "quote"); + } + if (include) field += '"'; + current_idx += 2; + } else { + if (include) field += input[current_idx]; + current_idx++; + } + } + } + + num_fields_parsed++; + + if (include) { + // Add the tensor to the result + TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), + selector_idx, out_tensors)); + selector_idx++; + // Terminate early if we have all the fields we want + if (selector_idx == dataset()->select_cols_.size()) + return Status::OK(); + } + } // Exit condition: current_idx has reached the end of record + + // Check if the last field is empty, and include it if necessary + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); + if (include && !input.empty() && + input[input.size() - 1] == dataset()->delim_) { + TF_RETURN_IF_ERROR( + FieldToOutput(ctx, string(), selector_idx, out_tensors)); + } + + // Check that number of fields matches + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have ", + out_tensors->size(), " in record"); + } + return Status::OK(); + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status FieldToOutput(IteratorContext* ctx, string field, + size_t output_idx, + std::vector* out_tensors) { + if (output_idx >= dataset()->out_type_.size()) { + // We can get here if we're selecting all columns, but the number of + // fields exceeds the number of defaults provided + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have more in record"); + } + const DataType& dtype = dataset()->out_type_[output_idx]; + Tensor component(ctx->allocator({}), dtype, {}); + if ((field.empty() || field == dataset()->na_value_) && + dataset()->record_defaults_[output_idx].NumElements() != 1) { + // If the field is empty or NA value, and default is not given, + // report error. + return errors::InvalidArgument("Field ", output_idx, + " is required but missing in record!"); + } + + switch (dtype) { + // For each case, if the field is empty, we use the default. + // Otherwise, we convert it to the right type. + case DT_INT32: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int32 value; + if (!strings::safe_strto32(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int32: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_INT64: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int64 value; + if (!strings::safe_strto64(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int64: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_FLOAT: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + float value; + if (!strings::safe_strtof(field.c_str(), &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid float: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_DOUBLE: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + double value; + if (!strings::safe_strtod(field.c_str(), &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid double: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_STRING: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + component.scalar()() = std::move(field); + } + break; + } + default: + return errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", + output_idx); + } + out_tensors->push_back(std::move(component)); + return Status::OK(); + } + + // Sets up reader streams to read from the file at `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + input_stream_.reset( + new io::RandomAccessInputStream(file_.get(), false)); + // TODO(rachelim): Maintain our own buffer so we don't read every record + // twice + buffered_input_stream_.reset(new io::BufferedInputStream( + input_stream_.get(), dataset()->buffer_size_, false)); + if (dataset()->header_) { + // Ignore header line + string str; + Status s = buffered_input_stream_->ReadLine(&str); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument("Can't read header of empty file"); + } + } + return Status::OK(); + } + + // Resets all reader streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + input_stream_.reset(); + buffered_input_stream_.reset(); + file_.reset(); + } + + mutex mu_; + std::unique_ptr input_stream_ + GUARDED_BY(mu_); + std::unique_ptr buffered_input_stream_ + GUARDED_BY(mu_); + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_ + GUARDED_BY(mu_); // must outlive input_stream_ + }; // class Iterator + + const std::vector filenames_; + const bool header_; + const int64 buffer_size_; + const DataTypeVector out_type_; + const std::vector output_shapes_; + const std::vector record_defaults_; + const std::vector select_cols_; + const bool select_all_cols_; + const bool use_quote_delim_; + const char delim_; + const string na_value_; + }; // class Dataset + + DataTypeVector output_types_; + std::vector output_shapes_; +}; // class CSVDatasetOp + +// Register the kernel implementation for CSVDataset. +REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 137deb63527f0bdde7da8d5be83ed038f430e581..f271d269ab1b9339de4657e459dcbbd462890f0a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -34,6 +34,40 @@ data_input_datasets: `N` datasets with the same type that will be interleaved according to the values of `selector_input_dataset`. )doc"); +REGISTER_OP("CSVDataset") + .Input("filenames: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, + // `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); + // `record_defaults` must be a list of scalars...? + for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 6017e27e731e3e8bcdee516ea291b17cd0782e63..f5082228e885d065e659abf208ca7b94bb4999a5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -11,7 +11,10 @@ py_test( size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", # (b/79552534) + "no_pip", + ], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", @@ -117,6 +120,19 @@ py_library( ], ) +py_test( + name = "csv_dataset_op_test", + size = "small", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:readers", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", @@ -287,6 +303,7 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -301,6 +318,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -411,6 +429,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 6588fd04acb02790f5002c2e983253c3ed0504cf..2568b899d7ea1be685036ad8af93f584f861c951 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest( class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest( self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..641a389c033687ebe081963182390b00230e4cb5 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,378 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class CsvDatasetOpTest(test.TestCase): + + def _assert_datasets_equal(self, g, ds1, ds2): + assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' + '%s') % (ds1.output_shapes, + ds2.output_shapes) + assert ds1.output_types == ds2.output_types + assert ds1.output_classes == ds2.output_classes + next1 = ds1.make_one_shot_iterator().get_next() + next2 = ds2.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = sess.run(next1) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + sess.run(next2) + break + op2 = sess.run(next2) + self.assertAllEqual(op1, op2) + + def setup_files(self, inputs): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) + with open(fn, 'w') as f: + f.write('\n'.join(ip)) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self.setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + with ops.Graph().as_default() as g: + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(g, dataset_actual, dataset_expected) + + def _test_dataset(self, + inputs, + expected_output=None, + expected_err_re=None, + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, **kwargs) + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def testCsvDataset_floatRequired(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withEmptyValues(self): + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', ',6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 + inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + self._test_dataset( + inputs, + expected_err_re="Can't read header of empty file", + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 0 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + with ops.Graph().as_default() as g: + self._assert_datasets_equal(g, + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, dtypes.float32] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. + """ + + def _setUp(self): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._batch_size = 500 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'w') as f: + # Just write 10 rows and use `repeat`... + row = ','.join(['1.23456E12' for _ in range(n)]) + f.write('\n'.join([row for _ in range(10)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + for _ in range(5): + sess.run(next_element) + deltas = [] + for _ in range(10): + start = time.time() + sess.run(next_element) + end = time.time() + deltas.append(end - start) + median_wall_time = np.median(deltas) / 100 + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._batch_size, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkBatchThenMap(self): + self._setUp() + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.batch(self._batch_size) + self._runBenchmark(dataset, num_cols, 'csv_map_then_batch') + self._tearDown() + + def benchmarkCsvDataset(self): + self._setUp() + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + dataset = dataset.batch(self._batch_size) + self._runBenchmark(dataset, num_cols, 'csv_fused_dataset') + self._tearDown() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1075302bae96ca2e0111efbacdf5e919ea76897d..e0237198b7d47eb98eeffe88d28bf9681b2722c6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -256,6 +257,29 @@ class TFRecordDatasetSerializationTest( lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) +def _interleave(iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): @@ -355,8 +379,8 @@ class ReadBatchFeaturesTest(test.TestCase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return self._interleave([_next_record([i]) for i in file_indices], - cycle_length) + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) file_batch = [] keywords_batch_indices = [] @@ -397,28 +421,6 @@ class ReadBatchFeaturesTest(test.TestCase): [len(file_batch), keywords_batch_max_len], record_batch ] - def _interleave(self, iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - def _verify_records(self, sess, batch_size, @@ -620,14 +622,12 @@ class MakeCsvDatasetTest(test.TestCase): f.close() return fn - def _create_file(self, fileno, header=True, comment=True): + def _create_file(self, fileno, header=True): rows = [] if header: rows.append(self.COLUMNS) for recno in range(self._num_records): rows.append(self._csv_values(fileno, recno)) - if comment: - rows.append("# Some comment goes here. Ignore me.") return self._write_file("csv_file%d.csv" % fileno, rows) def _create_files(self): @@ -648,9 +648,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=False, shuffle_seed=None, header=True, - comment="#", na_value="", - default_float_type=dtypes.float32, ): return readers.make_csv_dataset( filenames, @@ -662,9 +660,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=shuffle, shuffle_seed=shuffle_seed, header=header, - comment=comment, na_value=na_value, - default_float_type=default_float_type, select_columns=select_cols, ) @@ -786,29 +782,6 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, label_name=None) - def testMakeCSVDataset_withNoComments(self): - """Tests that datasets can be created from CSV files with no header line. - """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), comment=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - comment=None, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) - def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -876,7 +849,7 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. """ - # Test that it works with standard test files (with comments, header, etc) + # Test that it works with standard test files (with header, etc) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -889,7 +862,9 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, defaults=[[], [], [], [], [""]]) - # Test on a deliberately tricky file + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, @@ -914,20 +889,29 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float32, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) assert features["col%d" % i].dtype == expected_dtypes[i] for i in range(len(rows)): assert sess.run(features) == dict(zip(col_names, expected[i])) - # With float64 as default type for floats + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string, dtypes.string ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -936,7 +920,6 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float64, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match @@ -1086,5 +1069,189 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) +class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78..4148addf2878c99f47ebe1454edf69ad7f38dfbc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import os import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase): conn.commit() conn.close() + +class SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 7a3e42cc72755c67b910db99c0238f6ba780a942..eceecfd1744d0ae28953a4504450653efa473569 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -45,6 +45,27 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 42ec2b0b017973c60efb4c2e1c99a7a2292da58b..b9393de4e90ae2597045b29070934b94e18cfcbd 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb035e573b70e8b19570e4e8ceca3c005..0d71be66018eeebe60de9deff24ceb6854d209d9 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,10 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import saver +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook def make_saveable_from_iterator(iterator): @@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator): return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access -class _Saveable(saver.BaseSaverBuilder.SaveableObject): +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): """SaveableObject for saving/restoring iterator state.""" def __init__(self, iterator_resource): serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) specs = [ - saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") ] super(_Saveable, self).__init__(iterator_resource, specs, iterator_resource.name) @@ -75,3 +77,182 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, unused_restored_shapes): with ops.colocate_with(self.op): return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval are small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separate from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS + collector when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. + """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input__.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. + self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer. + # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. + # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an initialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = saver_lib.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. + """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30a993b1f7056b9726f524b2279131339c80c5eb --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test.TestCase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.Variable( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = saver_lib.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def _input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. + est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index bbb808fbd7730002e48cab47fa8d0fe09e2124d2..75c31a944a09462f534f6ae3e3204c812ecf28d9 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,15 +18,16 @@ from __future__ import division from __future__ import print_function import csv -from math import ceil import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -68,7 +67,7 @@ def _is_valid_float(str_val, float_dtype): return False -def _infer_type(str_val, na_value, prev_type, float_dtype): +def _infer_type(str_val, na_value, prev_type): """Given a string, infers its tensor type. Infers the type of a value by picking the least 'permissive' type possible, @@ -79,29 +78,34 @@ def _infer_type(str_val, na_value, prev_type, float_dtype): na_value: Additional string to recognize as a NA/NaN CSV value. prev_type: Type previously inferred based on values of this column that we've seen up till now. - float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type - to parse float strings as. Returns: Inferred dtype. """ if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type return prev_type - if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): - return dtypes.int32 + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most - if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, - dtypes.int64): - return dtypes.int64 + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions - if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: - return float_dtype + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] - return dtypes.string - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment): +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" for fn in filenames: with file_io.FileIO(fn, "r") as f: rdr = csv.reader( @@ -112,9 +116,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, next(rdr) # Skip header lines for csv_row in rdr: - if comment is not None and csv_row[0].startswith(comment): - continue # Skip comment lines - if len(csv_row) != num_cols: raise ValueError( "Problem inferring types: CSV row has different number of fields " @@ -123,22 +124,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, comment, float_dtype, - num_rows_for_inference, select_columns): + na_value, header, num_rows_for_inference, + select_columns): """Infers column types from the first N valid CSV records of files.""" if select_columns is None: select_columns = range(num_cols) inferred_types = [None] * len(select_columns) for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment)): + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): if num_rows_for_inference is not None and i >= num_rows_for_inference: break for j, col_index in enumerate(select_columns): inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j], float_dtype) + inferred_types[j]) # Replace None's with a default type inferred_types = [t or dtypes.string for t in inferred_types] @@ -198,6 +198,112 @@ def _get_sorted_col_indices(select_columns, column_names): return result +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset( + file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=None, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch-size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. + """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now. + num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + if parser_fn is None: + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. + dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size is None: + prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + def make_csv_dataset( file_pattern, batch_size, @@ -209,7 +315,6 @@ def make_csv_dataset( use_quote_delim=True, na_value="", header=True, - comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, @@ -218,7 +323,6 @@ def make_csv_dataset( num_parallel_reads=1, num_parallel_parser_calls=2, sloppy=False, - default_float_type=dtypes.float32, num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -231,8 +335,8 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV records. See @{tf.gfile.Glob} for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. column_names: An optional list of strings that corresponds to the CSV columns, in order. One per column of the input record. If this is not provided, infers the column names from the first row of the records. @@ -272,15 +376,11 @@ def make_csv_dataset( header: A bool that indicates whether the first rows of provided CSV files correspond to header lines with column names, and should not be included in the data. - comment: An optional character string that marks lines that should not be - parsed as csv records. If this is provided, all lines that start with - this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size - ensures better shuffling, but would increase memory usage and startup - time. + ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of @@ -294,8 +394,6 @@ def make_csv_dataset( produced is deterministic prior to shuffling (elements are still randomized if `shuffle=True`. Note that if the seed is set, then order of elements after shuffling is deterministic). Defaults to `False`. - default_float_type: Either `tf.float32` or `tf.float64`. If defaults are - not provided, float-like strings are interpreted to be this type. num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. @@ -317,8 +415,6 @@ def make_csv_dataset( dataset = dataset.shuffle(len(filenames), shuffle_seed) # Clean arguments; figure out column names and defaults - if comment is not None and len(comment) != 1: - raise ValueError("`comment` arg must be a single-character string or None") if column_names is None: if not header: @@ -341,8 +437,7 @@ def make_csv_dataset( # construction time column_defaults = _infer_column_defaults( filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, comment, default_float_type, num_rows_for_inference, - select_columns) + header, num_rows_for_inference, select_columns) if select_columns is not None and len(column_defaults) != len(select_columns): raise ValueError( @@ -356,71 +451,189 @@ def make_csv_dataset( if label_name is not None and label_name not in column_names: raise ValueError("`label_name` provided must be one of the columns.") - # Define map and filter functions - def filter_fn(line): - return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) - def filename_to_dataset(filename): - ds = core_readers.TextLineDataset(filename) - if header: - ds = ds.skip(1) - if comment is not None: - ds = ds.filter(filter_fn) - return ds + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) - def decode_csv(line): - """Decodes CSV line into features. + def map_fn(*columns): + """Organizes columns into a features dictionary. Args: - line: String tensor corresponding to one csv record. + *columns: list of `Tensor`s corresponding to one csv record. Returns: A dictionary of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. """ - columns = parsing_ops.decode_csv( - line, - column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - ) features = dict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label return features - # Read files sequentially or in parallel + # Read files sequentially (if num_parallel_reads=1) or in parallel dataset = dataset.apply( interleave_ops.parallel_interleave( filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - dataset = dataset.repeat(num_epochs) - - # Use map_and_batch for perf - # TODO(b/76425672): use num_parallel_calls for better performance tuning when - # that is added - dataset = dataset.apply( - batching.map_and_batch( - map_func=decode_csv, - batch_size=batch_size, - num_parallel_batches=int( - ceil(num_parallel_parser_calls / batch_size)))) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) + return dataset +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +class CsvDataset(dataset_ops.Dataset): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180) + Note that we allow leading and trailing spaces with int or float field. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.contrib.data.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. If + both this and `select_columns` are specified, these must have the same + lengths, and `column_defaults` is assumed to be sorted in order of + increasing column index. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. + """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return contrib_gen_dataset_ops.csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + ) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + def make_batched_features_dataset(file_pattern, batch_size, features, @@ -480,8 +693,8 @@ def make_batched_features_dataset(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be @@ -537,16 +750,8 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.map(lambda _, v: v) # Apply dataset repeat and shuffle transformations. - repeat_dataset = (num_epochs != 1) - if repeat_dataset and shuffle: - # Used fused shuffle_and_repeat operation for better performance - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif repeat_dataset: - dataset = dataset.repeat(num_epochs) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) @@ -620,8 +825,8 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8dfcaf6032e1602ed76a8a995553c5d398c4a778..64a77bbed1d55c3d95329d9c7783c2b468bde745 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -26,7 +26,6 @@ py_library( "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", "//tensorflow/python:distribute", @@ -34,6 +33,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", ], ) @@ -151,6 +151,7 @@ py_library( ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", @@ -469,24 +470,24 @@ py_library( ], ) -py_test( +cuda_py_test( name = "cross_tower_ops_test", srcs = ["cross_tower_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ + additional_deps = [ ":combinations", ":cross_tower_ops", ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 946310aa6fc2101d75e86d3ff2e9f3284e6c6625..15935817b0283ebc04b95304afe41d8690a11442 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -51,6 +51,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -69,26 +70,31 @@ def generate(combinations): -- there should always be a "mode" argument. Accepted values are "eager" and "graph". -- arguments of the test method must match by name to get the corresponding - value of the combination. Tests must accept all arguments (except "mode", - which is optional). - -- distribution argument is special. It is meant for passing instances of - DistributionStrategy. Each instance is to be passed as `(, - )` tuple, where is the number of required - GPUs. If the required number of GPUs for the DistributionStrategy isn't - available then the test case is going to be skipped. + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed as via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If not `None`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs aren't available. Returns: a decorator that will cause the test method to be run under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph. + ValueError - if "mode" argument wasn't either "eager" or "graph". """ def decorator(test_function): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. + named_combinations = [] for combination in combinations: # We use OrderedDicts in `combine()` and `times()` to ensure stable # order of keys in each dictionary. @@ -99,30 +105,46 @@ def generate(combinations): "".join(filter(str.isalnum, str(value)))) for key, value in combination.items() ]) - combination.update({"testcase_name": "_test{}".format(name)}) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) - @parameterized.named_parameters(*combinations) + @parameterized.named_parameters(*named_combinations) def decorated(self, **kwargs): """A wrapped test method that sets up `test_function`.""" assert "mode" in kwargs mode = kwargs["mode"] - if "distribution" in kwargs: - distribution = kwargs["distribution"] - kwargs["distribution"] = distribution.strategy - if distribution.required_tpu and not TPU_TEST: - self.skipTest("Test requires a TPU, but it's not available.") - if not distribution.required_tpu and TPU_TEST: - self.skipTest("Test that doesn't require a TPU.") - - if not distribution.required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < distribution.required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". - format(distribution.required_gpus, context.num_gpus())) + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. requested_arguments = tf_inspect.getfullargspec(test_function).args missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( set(requested_arguments + ["mode"])) @@ -159,7 +181,8 @@ def combine(**kwargs): can be computed using `times()`. Args: - **kwargs: keyword arguments of form `option=[possibilities, ...]`. + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. Returns: a list of dictionaries for each combination. Keys in the dictionaries are @@ -178,6 +201,8 @@ def combine(**kwargs): key = first[0] values = first[1] + if not isinstance(values, list): + values = [values] return [ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) @@ -262,21 +287,31 @@ class NamedDistribution(object): return self._required_tpu +default_strategy = NamedDistribution( + "Default", + distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), - None) + required_gpus=None) +tpu_strategy_single_iteration = NamedDistribution( + "TPUSingleIteration", + tpu_strategy.TPUStrategy(iterations_per_step=1), + required_tpu=True) tpu_strategy = NamedDistribution( "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +# Note that we disable prefetching for testing since prefetching makes +# the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) -mirrored_strategy_without_prefetch = NamedDistribution( - "MirroredCPUAndGPUNoPrefetch", mirrored_strategy.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + mirrored_strategy.MirroredStrategy( + ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + required_gpus=2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 219b24160f3902fcfa5363cc39a8fc5b30d00308..184bcf27e59d68b82d28e8f01890c04f214c017c 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -41,6 +41,15 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_combine_single_parameter(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 2, + "b": 2 + }], combinations.combine(a=[1, 2], b=2)) + def test_add(self): self.assertEqual( [{ diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index b87224251ca3844fc81c6f32a893d2c71664a955..2b05884b9b93470ef9a764cbedbc91bd3912c611 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An example tf.keras model that is trained using MirroredStrategy.""" +"""An example of training tf.keras Model using MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from sys import argv + +import sys + import numpy as np import tensorflow as tf @@ -33,30 +35,37 @@ def input_fn(): def main(args): if len(args) < 2: - print('You must specify model_dir for checkpoints such as' - ' /tmp/tfkeras_example./') + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example/.') return - print('Using %s to store checkpoints.' % args[1]) - - strategy = tf.contrib.distribute.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) - optimizer = tf.train.GradientDescentOptimizer(0.2) + model_dir = args[1] + print('Using %s to store checkpoints.' % model_dir) + # Define tf.keras Model. model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + # Compile tf.keras Model. + optimizer = tf.train.GradientDescentOptimizer(0.2) model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() tf.keras.backend.set_learning_phase(True) + + # Define a DistributionStrategy and convert the tf.keras Model to a + # tf.Estimator that utilizes the DistributionStrategy. + strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( - keras_model=model, config=config, model_dir=args[1]) + keras_model=model, config=config, model_dir=model_dir) + # Train and evaluate the tf.Estimator. keras_estimator.train(input_fn=input_fn, steps=10) eval_result = keras_estimator.evaluate(input_fn=input_fn) print('Eval result: {}'.format(eval_result)) + if __name__ == '__main__': - tf.app.run(argv=argv) + tf.app.run(argv=sys.argv) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index e134fe34e10be402f028db986b8cbf14222db07f..5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -44,13 +44,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True]), - combinations.combine(is_tpu=[False])) + - combinations.combine( - distribution=[combinations.tpu_strategy], - optimizer_fn=[combinations.adam_optimizer_v1_fn], - mode=["graph"], - use_callable_loss=[False], - is_tpu=[True])) + combinations.combine(is_tpu=[False])) + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + # TODO(isaprykin): Make Adam v2 work with while_loops + # and TPUs. + ], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): with distribution.scope(): @@ -101,7 +104,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn, ], mode=["graph"], is_tpu=[True])) @@ -171,13 +175,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): set(created_variables)) @combinations.generate( - combinations.times(combinations.distributions_and_v1_optimizers(), - combinations.combine( - mode=["graph", "eager"], - momentum=[0.8, 0.9, 0.99], - renorm=[False, True]))) + combinations.times( + combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + is_tpu=[False], + # TODO(isaprykin): Allow False here. Currently subsequent + # towers will re-execute UPDATE_OPS of previous towers. + update_ops_in_cross_tower_mode=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + mode=["graph"], + is_tpu=[True], + update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm): + renorm, is_tpu, + update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) @@ -185,27 +204,30 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): optimizer_fn, batch_per_epoch=num_towers, momentum=momentum, - renorm=renorm) + renorm=renorm, + update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) - # Disable prefetching since that makes the specific input on each device - # to be non deterministic, and this test relies on specific input being - # on each device. + # Make sure prefetching is disabled since that makes the + # specific input on each device to be non deterministic, and + # this test relies on specific input being on each device. if isinstance(distribution, mirrored_strategy.MirroredStrategy): - distribution._prefetch_on_device = False + self.assertFalse(distribution._prefetch_on_device) iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return control_flow_ops.group( - distribution.unwrap( - distribution.call_for_each_tower( - model_fn, - iterator.get_next(), - run_concurrently=batchnorm.built)) + - ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + fetches = distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), + run_concurrently=batchnorm.built)) + if update_ops_in_cross_tower_mode: + fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + return control_flow_ops.group(fetches) if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -229,22 +251,40 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + @combinations.generate( combinations.times( combinations.combine( - distribution=[combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn], - loss_reduction=[losses_impl.Reduction.SUM, - losses_impl.Reduction.MEAN, - losses_impl.Reduction.SUM_OVER_BATCH_SIZE, - losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), - combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + loss_reduction=[ + losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS + ]), + combinations.times( + combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + is_tpu=[False]), + combinations.combine( + mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + is_tpu=[True], + mode=["graph"], + use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss): + use_callable_loss, is_tpu): with distribution.scope(): all_vars = [] @@ -280,12 +320,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) run_step() - self.assertEqual(distribution.num_towers, len(all_vars)) v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) @@ -312,6 +353,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 8237b23dbbdb10c053de53880d6838113b99be2d..89f2c431fece63269928fec6aa6d23b5a79ba0b9 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -111,10 +111,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): kwargs["name"] = "%s/replica_%d" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): - initial_value = index[devices[0]].value() + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) else: - initial_value = index[devices[0]].initial_value - kwargs["initial_value"] = array_ops.identity(initial_value) + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 6c5c055070c0fc88ed8f3a459e3f346596f077a6..3f9a02b249dde9a66056ed8952b664bbc3f74ead 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,9 +28,12 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -116,7 +119,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.assertEqual(expected, self.evaluate(unwrapped[0])) -@test_util.with_c_api class MirroredStrategyVariableCreationTest(test.TestCase): config = config_pb2.ConfigProto() @@ -370,22 +372,27 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected_sum = 0.0 expected_mean = 0.0 for i, d in enumerate(dist.worker_devices): - # Test access within a device scope, should see different values. - with ops.device(d): - v_sum_value = self.evaluate(ret_v_sum.read_value()) - v_mean_value = self.evaluate(ret_v_mean.read_value()) - expected = i + 3.0 - self.assertEqual(expected, v_sum_value) - expected_sum += expected - expected = i * 6.0 - self.assertEqual(expected, v_mean_value) - expected_mean += expected - - # fetch() should return the value you get by applying the - # reduction across all towers. - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected expected_mean /= len(dist.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all towers (whether you use + # fetch(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + if not context.executing_eagerly(): + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -431,6 +438,30 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. + self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a1ef0ecc77a8e8432dfa4eb6da7c324b371dab70..61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribute as distribute_lib -@test_util.with_c_api class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): @@ -53,7 +52,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) -@test_util.with_c_api class VariableCreatorStackTest(test.TestCase): def testCreatorStacksAreThreadLocal(self): diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py index ee7588163e42ee3c31dd9fd25fc53e3483f0fbee..09c859b32a3150b95fbfcfa5b62b5eca426ddf18 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -25,11 +25,9 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.training import server_lib -@test_util.with_c_api class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7101ed0756f44b846f10ddc6d429afe005a2f196..7aad8a953cbedd30b48739416e74b3dc164dc4cd 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -24,7 +24,6 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util -@test_util.with_c_api class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index 713494d603b855be2863af9f24ab98d4cf048042..a0b452fc2d445d1cf7dbf5e8fe0e29edef516207 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -44,7 +44,6 @@ class CanonicalizeVariableNameTest(test.TestCase): self.assertEquals("foo_a", self._canonicalize("foo_a")) -@test_util.with_c_api class SharedVariableCreatorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 0db0b59fcacee2785eb8191bb84ed5216a79b081..d1fdb3279cf2a7cba6e2282d58eedccf38bd38a3 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -22,6 +22,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.layers import normalization from tensorflow.python.ops import array_ops @@ -59,7 +60,7 @@ def minimize_loss_example(optimizer_fn, # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=2, drop_remainder=True)) + batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None @@ -68,11 +69,10 @@ def minimize_loss_example(optimizer_fn, layer = core.Dense(1, use_bias=use_bias) - def model_fn(xs): + def model_fn(x): """A very simple model written by the user.""" def loss_fn(): - x = math_ops.reduce_mean(xs, keepdims=True) y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y @@ -89,7 +89,8 @@ def minimize_loss_example(optimizer_fn, def batchnorm_example(optimizer_fn, batch_per_epoch=1, momentum=0.9, - renorm=False): + renorm=False, + update_ops_in_tower_mode=False): """Example of non-distribution-aware legacy code with batch normalization.""" def dataset_fn(): @@ -103,12 +104,19 @@ def batchnorm_example(optimizer_fn, optimizer = optimizer_fn() batchnorm = normalization.BatchNormalization( renorm=renorm, momentum=momentum, fused=False) + layer = core.Dense(1, use_bias=False) def model_fn(x): + """A model that uses batchnorm.""" def loss_fn(): - y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) - loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + y = batchnorm(x, training=True) + with ops.control_dependencies( + ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops_in_tower_mode else []): + loss = math_ops.reduce_mean( + math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) + # `x` and `y` will be fetched by the gradient computation, but not `loss`. return loss # Callable loss. diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index a7e4fe80f3e65907fa4b48c5fe0fcfd422ba033f..75441786a615fc0d87b4c4b0b45b9384d678c1d3 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import nest -# TODO(isaprykin): Consider whether inheriting is really appropriate. class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" @@ -73,7 +72,6 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def infeed_input(i): """Get input, split it and then enqueue.""" iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] for inputs_per_core in iteration_inputs] for core_id in range(self._num_cores_per_host)] @@ -117,3 +115,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): iterate_on_tpu, [], num_shards=self._num_cores_per_host) return control_flow_ops.group(tpu_result, enqueue_ops) + + def _reduce(self, method_string, value, destinations): + del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + if method_string == 'mean': + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self._num_cores_per_host) + return tpu_ops.cross_replica_sum(value) + + @property + def num_towers(self): + return self._num_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index aaf177d07ead6978db45277252540e3b329f2bc3..49b4e24daa4ffe417712bc854aa29995d5afc408 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -34,10 +34,11 @@ from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.ops import math_ops from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -60,7 +61,7 @@ class DistributedValues(object): else: device = distribute_lib.get_update_device() if device is None: - device = device_util.current() + return self._get_cross_tower() device = device_util.canonicalize(device) try: return self._index[device] @@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op - def _as_graph_element(self): - # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): - return self._primary_var._as_graph_element() - return self.get()._as_graph_element() - def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored, def assign(self, *args, **kwargs): return self.get(device=_get_update_device()).assign(*args, **kwargs) + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice, super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): + _assert_tower_context() return self.get().assign(*args, **kwargs) @property def reduce_method(self): return self._reduce_method + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. + total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -672,11 +702,12 @@ class MultiWorkerDataset(object): return MultiWorkerDataIterator(iterators, self._worker_device_map) -class PerIteration(object): - """Holds input for multiple iterations at once.""" +class _PerKey(object): + """Holds data associated by keys.""" - def __init__(self, index): - self._index = index + def __init__(self, *index): + # pylint: disable=protected-access + self._index = list(index) def get(self, iteration): return array_ops.gather(self._index, iteration) @@ -687,6 +718,24 @@ class PerIteration(object): def get_dtype(self): return self._index[-1][-1].dtype + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + +class PerIteration(_PerKey): + """Holds input for multiple iterations at once.""" + + def __init__(self, *index): + # pylint: disable=protected-access + super(PerIteration, self).__init__(*[batch._index for batch in index]) + + +class Batches(_PerKey): + pass + class MultiIterator(object): """Iterator that returns results of multiple get_next()s.""" @@ -697,11 +746,31 @@ class MultiIterator(object): self._batches_per_iteration = batches_per_iteration def get_next(self, name=None): - return PerIteration([[ - self._dataset_iterator.get_next(name=name) - for _ in range(self._batches_per_iteration) - ] - for _ in range(self._iterations)]) + """Return PerIteration with `iterations x batches_per_iteration` inputs.""" + data = [] + for _ in range(self._batches_per_iteration): + batch = [] + for _ in range(self._iterations): + batch.append(self._dataset_iterator.get_next(name=name)) + data.append(batch) + + # Here is an example. Suppose each get_next returns a tuple of two tensors. + # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: + # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] + # + # After the first `map_structure` it gets transformed to: + # [(Batches(a, A), Batches(z, Z)), + # (Batches(b, B), Batches(y, Y)), + # (Batches(c, C), Batches(x, X))] + # + # After the second `map_structure` it gets transformed to a tuple of: + # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), + # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) + + data = nest.map_structure(Batches, *data) + data = nest.map_structure(PerIteration, *data) + + return data @property def initializer(self): diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 9aeef9fa3e86f25ba2544236fd802c7162f4e40e..1c95758d96aba47e9581dde6411763e98b99a968 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -42,7 +42,6 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest -@test_util.with_c_api class DistributedValuesTest(test.TestCase): def testGetEager(self): @@ -81,7 +80,6 @@ class DistributedValuesTest(test.TestCase): v = values.DistributedValues({"/device:cpu:0": 42}) -@test_util.with_c_api class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() @@ -164,7 +162,6 @@ def _make_mirrored(): return v, devices, mirrored -@test_util.with_c_api class RegroupAndSelectDeviceTest(test.TestCase): def _is_per_device(self, result, expected, klass=values.PerDevice): @@ -317,7 +314,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -@test_util.with_c_api class PerDeviceDatasetTest(test.TestCase): config = config_pb2.ConfigProto() @@ -564,7 +560,6 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): multi_worker_iterator.get_next() -@test_util.with_c_api class MirroredVariableTest(test.TestCase): config = config_pb2.ConfigProto() @@ -741,7 +736,6 @@ def _make_tower_local(method): return v, tower_local -@test_util.with_c_api class TowerLocalVariableTest(test.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index fad613155d8861a2508fb7aca752b10ff85d35eb..6192f04c8b695d124b498850ad430823b44fd472 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -94,7 +94,7 @@ cuda_py_test( cuda_py_test( name = "distribution_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/distribution_test.py"], additional_deps = [ ":distributions_py", @@ -337,7 +337,7 @@ cuda_py_test( cuda_py_test( name = "mvn_tril_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", @@ -372,6 +372,7 @@ cuda_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], + shard_count = 4, ) cuda_py_test( @@ -459,7 +460,7 @@ cuda_py_test( cuda_py_test( name = "batch_reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/batch_reshape_test.py"], additional_deps = [ ":distributions_py", @@ -578,7 +579,7 @@ cuda_py_test( cuda_py_test( name = "wishart_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/wishart_test.py"], additional_deps = [ ":distributions_py", @@ -709,6 +710,7 @@ cuda_py_test( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", ], + shard_count = 4, tags = ["noasan"], # times out, http://b/78588814 ) @@ -866,7 +868,7 @@ cuda_py_test( cuda_py_test( name = "batch_normalization_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], additional_deps = [ ":bijectors_py", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 59d549b7b80a3d80d0b8409542eb6583f645bdaa..f2bb2d3325a7cc6ec5803860600149522752a4c0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -448,8 +448,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r"`batch_shape` size must match " - r"`distributions.batch_shape` size"): + with self.assertRaisesOpError(r"Shape sizes do not match."): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -457,8 +456,13 @@ class _BatchReshapeTest(object): def test_non_positive_shape(self): dims = 2 - new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. - old_batch_shape = [2] + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. new_batch_shape_ph = ( constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape @@ -471,7 +475,7 @@ class _BatchReshapeTest(object): mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) if self.is_static_shape: - with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -479,7 +483,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r".*must be positive.*"): + with self.assertRaisesOpError(r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ca20442c3940664feab7526110229872a6cdc41f..dc45114b1c23b5edb78d68ad4f38f5201d265170 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase): -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 46f2c63f9b0f78b25bb1948e6ea55ab20c5cfa6e..d44e49b4874a5b91f7633cd9c97dbb1a7da70f27 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,15 +22,12 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. @@ -265,7 +262,6 @@ class _ReshapeBijectorTest(object): raise NotImplementedError("Subclass failed to implement `build_shapes`.") -@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -305,21 +301,13 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): - if ops._USE_C_API: - error_message = "Invalid value in tensor used for shape: -2" - else: - error_message = "elements must be either positive integers or `-1`." - self._testInvalidDimensionsOpError(error_message) + self._testInvalidDimensionsOpError( + "Invalid value in tensor used for shape: -2") def testInputOutputMismatchOpError(self): - if ops._USE_C_API: - error_message = "Cannot reshape a tensor with" - else: - error_message = "Input to reshape is a tensor with" - self._testInputOutputMismatchOpError(error_message) + self._testInputOutputMismatchOpError("Cannot reshape a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -341,7 +329,6 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): self._testInputOutputMismatchOpError("Input to reshape is a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py index 7435bcbc684c1660a648cef4ab30c888723853f8..b003526392709b61e9cc46e0ff8e5fa78edc0568 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py @@ -131,8 +131,8 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): return mu, sigma def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. + mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalFullCovariance( + loc=mu_a, + covariance_matrix=sigma_a, + validate_args=True) + mvn_b = ds.MultivariateNormalFullCovariance( + loc=mu_b, + covariance_matrix=sigma_b, + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index 685f32883dae5b8513badeb05e1508cd611d6e93..b556d06123800f22f5d9a90dd18f3c745aec90a1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase): return mu, sigma def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) + batch_shape = [] + event_shape = [2] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl, kl_v) def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. + mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLTwoIdenticalDistributionsIsZero(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mvn_a = ds.MultivariateNormalTriL( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 968057331787059240110b90545f70c0ab128aa8..b91a610acf1a9094d612504d63030b3bffb873ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 88ed0127841093cc1a1168d988f14e7bb0277b12..d813831bef803a22c095d9c98e7163aa4861a15d 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -144,7 +144,7 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index bf5590cd552a915a3ecfc1912ee530baf79665a6..c709318f76552e1188f735f5bafff4be0537baed 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -41,9 +42,6 @@ class BatchReshape(distribution_lib.Distribution): This "meta-distribution" reshapes the batch dimensions of another distribution. - Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support - `-1` for flattening. - #### Examples ```python @@ -51,7 +49,7 @@ class BatchReshape(distribution_lib.Distribution): dtype = np.float32 dims = 2 - new_batch_shape = [1, 2, 3] + new_batch_shape = [1, 2, -1] old_batch_shape = [6] scale = np.ones(old_batch_shape + [dims], dtype) @@ -85,8 +83,9 @@ class BatchReshape(distribution_lib.Distribution): Args: distribution: The base distribution instance to reshape. Typically an instance of `Distribution`. - batch_shape: Positive `int`-like vector-shaped `Tensor` representing the - new shape of the batch dimensions. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -104,31 +103,28 @@ class BatchReshape(distribution_lib.Distribution): ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() name = name or "BatchReshape" + distribution.name - self._distribution = distribution with ops.name_scope(name, values=[batch_shape]) as name: - self._batch_shape_ = ops.convert_to_tensor( - batch_shape, - dtype=dtypes.int32, - name="batch_shape") - self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) - if self._batch_shape_static is not None: - self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = validate_init_args( - self._distribution, - self._batch_shape_, - validate_args, - self._batch_shape_static) + # The unexpanded batch shape may contain up to one dimension of -1. + self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions super(BatchReshape, self).__init__( - dtype=self._distribution.dtype, - reparameterization_type=self._distribution.reparameterization_type, + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [self._batch_shape_] + - self._distribution._graph_parents), # pylint: disable=protected-access + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access name=name) @property @@ -140,7 +136,7 @@ class BatchReshape(distribution_lib.Distribution): return array_ops.identity(self._batch_shape_) def _batch_shape(self): - return tensor_shape.TensorShape(self._batch_shape_static) + return self._batch_shape_static def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -152,11 +148,13 @@ class BatchReshape(distribution_lib.Distribution): def _sample_n(self, n, seed=None): with ops.control_dependencies(self._runtime_assertions): x = self.distribution.sample(sample_shape=n, seed=seed) - new_shape = array_ops.concat([ - [n], - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) return array_ops.reshape(x, new_shape) def _log_prob(self, x): @@ -213,9 +211,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] @@ -238,10 +236,11 @@ class BatchReshape(distribution_lib.Distribution): self.event_shape_tensor(), ], axis=0) result = fn(array_ops.reshape(x, old_shape)) - new_shape = array_ops.concat([ - sample_shape, - self.batch_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) result = array_ops.reshape(result, new_shape) if (static_sample_shape.ndims is not None and self.batch_shape.ndims is not None): @@ -261,8 +260,7 @@ class BatchReshape(distribution_lib.Distribution): if static_event_shape_list is None: static_event_shape_list = [self.event_shape] new_shape = array_ops.concat( - [self.batch_shape_tensor()] + event_shape_list, - axis=0) + [self._batch_shape_unexpanded] + event_shape_list, axis=0) result = array_ops.reshape(fn(), new_shape) if (self.batch_shape.ndims is not None and self.event_shape.ndims is not None): @@ -281,9 +279,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and @@ -355,62 +353,56 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions -def validate_init_args( - distribution, - batch_shape, - validate_args, - batch_shape_static): +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" - with ops.name_scope(name="validate_init_args", - values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) + + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) + + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 85ad23e4133ef09051cdc8b45e489caeea90fbb3..b158a51bb022b5e2ea3afda74e97b9dc131665a6 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -20,10 +20,9 @@ from __future__ import print_function import itertools -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector @@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) -def _maybe_get_event_ndims_statically(event_ndims): - static_event_ndims = (event_ndims if isinstance(event_ndims, int) - else tensor_util.constant_value(event_ndims)) - if static_event_ndims is not None: - return static_event_ndims - - return event_ndims - - def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -238,13 +228,13 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant( - 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) if not self.bijectors: return ildj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -258,11 +248,12 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically( + event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -274,13 +265,12 @@ class Chain(bijector.Bijector): def _forward_log_det_jacobian(self, x, **kwargs): x = ops.convert_to_tensor(x, name="x") - fldj = constant_op.constant( - 0., dtype=x.dtype, name="inverse_log_det_jacobian") + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) if not self.bijectors: return fldj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -293,13 +283,21 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) x = b.forward(x, **kwargs.get(b.name, {})) return fldj + def _maybe_get_event_ndims_statically(self, event_ndims): + event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( + event_ndims) + if event_ndims_ is None: + return event_ndims + return event_ndims_ + + diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 12d16031783b78dc3ea6273af77c1eaeb77ca94e..24b26bf124c78c8320b9a6bc3b900e6c7a93f5e4 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -163,7 +163,7 @@ class Binomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index daacfe657fe154dce8d0db98894fe8b73546c476..f5ffdd873124d6626dca26f603592bd0b030d7b3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ "Cauchy", @@ -120,7 +121,7 @@ class Cauchy(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index c77c5fd20895a6220604d76a95a152a22cd3d914..08cdc1582892cc7d308bd60f082dde082704f57f 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -83,7 +84,7 @@ class Chi2(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True @@ -119,7 +120,7 @@ class Chi2WithAbsDf(Chi2): validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index a42350430e98515e521ce357bf5a87ff2daefedc..6d7d6d307bd0f815344c8a0e347f45ae11ba6462 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ "Deterministic", @@ -86,7 +87,7 @@ class _BaseDeterministic(distribution.Distribution): Raises: ValueError: If `loc` is a scalar. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 53dd42f4c83fcea0ec5b1374c8e3109ebe1dd127..446cff6ec242f25178fed0c6a424791fa9f176ad 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -85,7 +85,7 @@ class Geometric(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 2c261073ee16462599740cb241108bfe08c773ec..ed9ea6f4f3ffe18fb6bf1e0a7d57728d010e0f01 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util class _Gumbel(distribution.Distribution): @@ -124,7 +125,7 @@ class _Gumbel(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index d0df2befd6e46ca93e5a0b5d1cb5407d6719c7f2..7e12767f6d8f6c61565ecf266d3b222de68c0e40 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -105,7 +106,7 @@ class HalfNormal(distribution.Distribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index fbde55ef310de1d926b8ddd503499fbed4809373..fa89fff3b7b2f8266a44c446a0c9807790b3aed8 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import util as distribution_util class Independent(distribution_lib.Distribution): @@ -116,7 +117,7 @@ class Independent(distribution_lib.Distribution): ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 502bd4f493337bab180129cd0ddfaf5a76a0ca4e..85e8e10466038e5e55ef4b754f82c0c2c2543b6d 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -125,7 +125,7 @@ class InverseGamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -280,7 +280,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma): validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index c83b5bc2e3a8c56f5c52d063a7d0d399be1c1870..0103283259b0526b5a108ea1836f95709eedc067 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util class Logistic(distribution.Distribution): @@ -119,7 +120,7 @@ class Logistic(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 2ef294af2e8bc9beff735ec2e0fd6b619ce96176..d54f30dc634ab5c8aa82066056266747b63eec21 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -116,7 +116,7 @@ class Mixture(distribution.Distribution): matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 0b1301e551728f74bb0048d2dcf3c356ae110c75..c7c90cf875484a1753577227bf22de878d00a502 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -130,7 +130,7 @@ class MixtureSameFamily(distribution.Distribution): ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index e3236c2db93695a5e007bba9a1414773f3935f2e..cad398582b9c939e8e96cf498638869ccd3701bd 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -193,7 +193,7 @@ class MultivariateNormalDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): @@ -224,7 +224,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 2f6a6f198cbcfbdcbd0993d3074ddde1c389585f..1c11594df3ad2612dd8746bb8785d86390b69937 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -215,7 +215,7 @@ class MultivariateNormalDiagPlusLowRank( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 5d06a396fe7a3b87cabb9c3081da45246854089f..47d7d13cf357f1ac657641420602c92eefdad197 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -155,7 +156,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() # Convert the covariance_matrix up to a scale_tril and call MVNTriL. with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 44c92312c7dc758500051f89923ec9fafe850c0e..79916fef8d7b752649dcc673a84ea45ccf460905 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -170,7 +170,7 @@ class MultivariateNormalLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d6f8b731cbeed5fed3b43365e7c668d0434a267e..d6b0ed994ec0a62e9b7684e7478130052a1fd300 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -179,7 +179,7 @@ class MultivariateNormalTriL( Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index eeaf9c0a5ebc1323e137ff73f82588f6907031c7..1085c56dc86c8d45bdab2e7cecedf44663e5c408 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -90,7 +90,7 @@ class NegativeBinomial(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 305b138fdc2318523ee078195213caf865d96b4d..a4b9f3b78d4fdcc328bac84623114b921b9ded49 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -115,7 +115,7 @@ class OneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index a84aad6fc9372395ac021fa3aa006ddf9272e6a9..b34539402102b8f289d4eb289fcb82f4030f4e8c 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -93,7 +93,7 @@ class Poisson(distribution.Distribution): TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 19c99dcee92978e938a73af9be445cd098e5fe90..fe72091d7d759e54c51eb666f2ceacc8371e55fd 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -255,7 +255,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index eb94760ad71f5babaedaafd3f7990b40aaad85c2..584d2c385fced95ec496bb8dae9556e5c376b66d 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -263,7 +263,7 @@ class QuantizedDistribution(distributions.Distribution): `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() values = ( list(distribution.parameters.values()) + [low, high]) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 84c8d29072c2f1f3888329638c4695bccf70eab7..0362996e684fb34b15cd98a2fc40df58087fbe95 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -165,7 +165,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Raises: ValueError: If both `probs` and `logits` are passed, or if neither. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 325f41e37c928ba8e81e45e63a7f7f8126bc80f8..910c430ae7f026a3ac9ce50d1d5936d4454cba41 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -162,7 +162,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349688511e19a4fa3d58a5b3c1c8355671a3..cf505ac627b62ae0a3d1ec1ce2a237c3c2ff1b74 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ class SeedStream(object): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0 diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index 03828fa61277eeaf7ce90de8023b4ed91f6cc4dc..f04dc8da39140240edbe4efb75de30e321436d55 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -132,7 +132,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[loc, scale, skewness, tailweight]) as name: diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index af6ff8162b173015dca2d568e13d63127af7853a..cd6d7499595d88d18de339371d4a07fe780662d9 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -395,7 +395,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index e265b5d0f7c10b2782a1a8924babdca9b986f622..3465d66b30501e7aebd9904d2ae2206d628c10b7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -175,7 +175,7 @@ class VectorExponentialDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 89136d6760bb663b5ff86a77c5945ce900f072b9..2c31b019845d7e4558eb3047af84732a2ae03986 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -175,7 +175,7 @@ class VectorExponentialLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 8dd983b750d9b39775e570800006011f4968f7f3..6a36018d6f1b83955ef9080ec11c74c08a670075 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -210,7 +210,7 @@ class VectorLaplaceDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index ec485c95c15da2794b67d2699d2bdd9db97bb6c4..97e5c76d800acd800e34a9e66a3c5fdd7ce4f660 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -191,7 +191,7 @@ class VectorLaplaceLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 1438ede26500bca4541fa9b2020ff22d4c071098..ff5ca4525700aedc88d75e391bf0c2415c2afa13 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -163,7 +163,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope( name, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 7e78ded9df07564126b46b6beeeccf95bf1eef94..4742f7521816d4643354017495f3380c78ac7bc2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -175,7 +175,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 91453fed5d279178a0e062b71dad3b0f957b11b4..f555867e7f3c2a6bc797e9b3d56da2fa434aba6f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -107,7 +107,7 @@ class _WishartLinearOperator(distribution.Distribution): ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() self._cholesky_input_output_matrices = cholesky_input_output_matrices with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): @@ -530,7 +530,7 @@ class WishartCholesky(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) @@ -646,7 +646,7 @@ class WishartFull(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 99abbae03fc14f241dae27f317902f7335819037..0cc764d2208c5b061b7b836bdf57a035f52c6fcf 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -120,7 +120,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -131,6 +130,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", + "//tensorflow/python/training/checkpointable:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 0783d1b5d70e502e6edd80b59f37fdd93b413e12..d7909dd5a2691a015a6afed2caa475b39ca7ebc3 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 7b123707cc3a26073088cf2c57c6211e831c19fd..68bec9aee894edd60a025ac1cf87ca3e010db842 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils class IteratorTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b80c90902353709b7f739585291ec3b5890c27c7..cc9cf53410f641cc3303b4450e9eaa1301904a64 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -227,7 +227,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tfe.GradientTape(persistent=True) as g: + with tf.GradientTape(persistent=True) as g: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -306,7 +306,7 @@ def main(_): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index bd35e50c1f434d167c5a8c5aa7d224912523ce28..81ac05e26d23c2fc53f63d64bb28bdea6072e396 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -111,5 +111,5 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 4e1380afb2e6e722de65c691d4fbf44621072e87..2259c20741ab689dbe0d08d32ff05fc7f8a2100d 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -119,7 +119,7 @@ def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() # Ground-truth constants. true_w = [[-2.0], [4.0], [1.0]] true_b = [0.5] diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index e53234b51a7dccc11e548ac81a7ef070c628aa52..2bc2fc2aa9150a3181db612439d0c37c8e76d1e3 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -117,5 +117,5 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 8517a3bf7b6aebf4ecd2f148d2160cfea1b1b9c0..2d51cfdeee3f0b45514af0895366417158b01614 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -36,9 +36,7 @@ def device_and_data_format(): 'channels_last') -def random_batch(batch_size, device_and_format=None): - _, data_format = device_and_format or device_and_data_format() - +def random_batch(batch_size, data_format): shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) shape = (batch_size,) + shape @@ -53,7 +51,7 @@ def random_batch(batch_size, device_and_format=None): def train_one_step(model, images, labels, optimizer): - with tfe.GradientTape() as tape: + with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) @@ -70,7 +68,7 @@ class ResNet50Test(tf.test.TestCase): if defun: model.call = tfe.defun(model.call) with tf.device(device), tfe.execution_mode(execution_mode): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) tfe.async_wait() self.assertEqual((2, 1000), output.shape) @@ -91,7 +89,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) @@ -101,7 +99,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) self.assertEqual((2, 2048), output.shape) @@ -115,7 +113,7 @@ class ResNet50Test(tf.test.TestCase): name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) tfe.async_wait() @@ -134,7 +132,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) gc.disable() # Warm up. Note that this first run does create significant amounts of # garbage to be collected. The hope is that this is a build-only effect, @@ -202,18 +200,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): # which forces a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False, execution_mode=None, - device_and_format=None): + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) batch_size = 64 num_burn = 5 num_iters = 30 with tf.device(device): - images, _ = random_batch(batch_size, device_and_format) + images, _ = random_batch(batch_size, data_format) for _ in xrange(num_burn): model(images, training=False).cpu() if execution_mode: @@ -227,30 +225,34 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_apply_sync(self): - self._benchmark_eager_apply('eager_apply', defun=False) + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) def benchmark_eager_apply_async(self): self._benchmark_eager_apply( - 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) def _benchmark_eager_train(self, label, make_iterator, + device_and_format, defun=False, execution_mode=None, - device_and_format=None): + compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size, device_and_format) + (images, labels) = random_batch(batch_size, data_format) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -273,18 +275,21 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train_sync(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) def benchmark_eager_train_async(self): self._benchmark_eager_train( 'eager_train_async', MockIterator, + device_and_data_format(), defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) def benchmark_eager_train_datasets(self): @@ -294,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, defun=False) + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -304,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index 75b342ba78bd5de5c2827296f6fba01ffa86d560..b7d8395e277b526ba40ccafa323ba453a8667b62 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -67,5 +67,5 @@ class RNNColorbotTest(tf.test.TestCase): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index be5d60449d7e08c99cc28e76befce56f468c77fd..74701b2f4f7448c5f6c2c1bd7c67a8ee112f4115 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -304,7 +304,7 @@ def test_model(use_cudnn_rnn): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() if not FLAGS.data_path: raise ValueError("Must specify --data-path") diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..638c57d1c92c1dce0ef9e73e9a6ac2369358080b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +cuda_py_test( + name = "scan_test", + size = "small", + srcs = ["scan_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "scan_graph_test", + size = "small", + srcs = ["scan_graph_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b8c8941ec411912f3089315d038fc4bcd049ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under graph mode execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + with tf.Session() as sess: + sess.run(sum_op) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py similarity index 52% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/eager/python/examples/scan/scan_test.py index 6b8ea83920733a3a442171616ab460ffaf831521..a02fc24c79dae6c2565db8b138b1d7391d169ed8 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -12,18 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DenseNet Keras applications.""" - +"""Unit test for tf.scan under eager execution.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + -del absolute_import -del division -del print_function +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 1e4746d01ca1a8d13162844bc064c479c7184237..8ac553e0ae71382966d03d9ef4429adf5137b369 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,8 +36,8 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 907f9204c2d31a652ca2a0539a23db4722b4e154..1ae6415d5ecb03ef97cdf734c808e3f728dafcb0 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -30,7 +30,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index f0fe4ce8c53bb80c03a3f0de37078bcdb975a0b4..98a98a8d358d8a7d7a06505ed1a7d4c0ff1e18f4 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -30,8 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils class MetricsTest(test.TestCase): @@ -146,8 +146,6 @@ class MetricsTest(test.TestCase): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -186,8 +184,8 @@ class MetricsTest(test.TestCase): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. + # Verify two metrics with the same name in the same graph raises a + # ValueError. with context.graph_mode(): m1 = metrics.Mean() m1(0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 44828bea50c660815e457f21a1990cd706c40876..f801d9a47b2f831a48d9b6335c69612c1356d800 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,9 +23,8 @@ import os import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -545,10 +545,10 @@ class Sequential(Network): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 6a51d03de52914d2ad0ac3ad05d1ba01d856ad9a..c92bd15b253b67a3301cd562046a4467e1bf877d 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -30,8 +30,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: disable=not-callable diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 79dd117854e5fe9f066f671d8ce62e08579e0ed9..5826700c73e255198e9a6974ca240ba55e438a26 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -120,9 +120,9 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable import Checkpointable -from tensorflow.python.training.checkpointable_utils import CheckpointableSaver -from tensorflow.python.training.checkpointable_utils import Checkpoint +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.util import CheckpointableSaver +from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index e80ccbb74d8623e977a98cb7fa5eb41f3c9bf250..db50b33af2e4f1cc6575d4b0d416d6d2669b5c35 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase): return math_ops.multiply(x, x) grad = tfe.gradients_function(square) - self.assertEquals([6], [x.numpy() for x in grad(3)]) + self.assertEquals([6], [x.numpy() for x in grad(3.)]) def testGradOfGrad(self): @@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase): grad = tfe.gradients_function(square) gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) - self.assertEquals([2], [x.numpy() for x in gradgrad(3)]) + self.assertEquals([2], [x.numpy() for x in gradgrad(3.)]) def testCustomGrad(self): @@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase): return y, grad_fn grad = tfe.gradients_function(f) - self.assertEquals([12], [x.numpy() for x in grad(3)]) + self.assertEquals([12], [x.numpy() for x in grad(3.)]) def testGPU(self): if tfe.num_gpus() <= 0: diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 571e2e3a5df08e09172ec7b6885c5c972b5abfb6..d5d2abf8c4c82374842ed2e10a849765a6dddd3b 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -14,11 +14,14 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", @@ -28,6 +31,49 @@ py_library( ], ) +py_library( + name = "baseline", + srcs = ["python/estimator/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:baseline", + ], +) + +py_test( + name = "baseline_test", + size = "small", + srcs = ["python/estimator/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":baseline", + ":head", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:metric_keys", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "boosted_trees", srcs = ["python/estimator/boosted_trees.py"], @@ -180,6 +226,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ @@ -239,6 +322,36 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], @@ -284,9 +397,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d43b3ea6bf2718ab7f5317b95fdd93ec9917dc69..788ac5ca7046d6dd30a3d5520b243944532622fa 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,11 +19,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * @@ -38,12 +41,14 @@ _allowed_symbols = [ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', 'logistic_regression_head', 'multi_class_head', 'multi_head', 'multi_label_head', 'poisson_regression_head', 'regression_head', + 'BaselineEstimator', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', @@ -56,6 +61,8 @@ _allowed_symbols = [ 'TowerOptimizer', 'RNNClassifier', 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..beffbee73064b9ef425b115317c43e29477b19af --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Baseline estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import baseline + + +class BaselineEstimator(estimator.Estimator): + """An estimator that can establish a simple baseline. + + The estimator uses a user-specified head. + + This estimator ignores feature values and will learn to predict the average + value of each label. E.g. for single-label classification problems, this will + predict the probability distribution of the classes as seen in the labels. + For multi-label classification problems, it will predict the ratio of examples + that contain each class. + + Example: + + ```python + + # Build baseline multi-label classifier. + estimator = BaselineEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3)) + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + estimator.train(input_fn=input_fn_train) + + # Evaluates cross entropy between the test and train labels. + loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + + # For each class, predicts the ratio of training examples that contain the + # class. + predictions = classifier.predict(new_samples) + + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` passed to the `head` constructor is not `None`, a feature + with `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + head, + model_dir=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. + """ + def _model_fn(features, labels, mode, config): + return baseline._baseline_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e3e670f7332811c1bfdaea65b0308ce59ade59 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -0,0 +1,430 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import baseline +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import saver + +# Names of variables created by model. +BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def _baseline_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a BaselineEstimator that uses regression_head.""" + return baseline.BaselineEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), + *args, **kwargs) + + +class BaselineEstimatorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_estimator = _baseline_estimator_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineEstimatorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # x * weight + bias = 2. * 10. + .2 = 20.2 + self.assertAllClose([[.2]], predicted_scores) + + def testMultiDim(self): + """Tests predict when all variables are multi-dimenstional.""" + batch_size = 2 + label_dimension = 3 + with ops.Graph().as_default(): + variables.Variable( # shape=[label_dimension] + [.2, .4, .6], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + # x shape=[batch_size, x_dim] + x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # score = bias, shape=[batch_size, label_dimension] + self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], + predicted_scores) + + +class BaselineEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, + input_dimension, label_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['predictions'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + +class BaselineEstimatorTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 6. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index cf6e3329d2e27735d8759cc2ab3726e8c624c6ae..7ff25b95c079c7e06d29e874bcaa0d2c13e7167e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -93,7 +93,7 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None): - """Initializes a `DNNClassifier` instance. + """Initializes a `DNNEstimator` instance. Args: head: A `_Head` instance constructed with a method such as diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000000000000000000000000000000000..03cf6f107c1c5589522d7be4946562a466740b0e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name(''linear/linear_model/age/weights') + ... + ``` + + This method is a wrapper for _export_all_saved_models, and wraps a raw + input_receiver_fn in a dictionary to pass in to that function. + See _export_all_saved_models for full docs. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + For each mode passed in via the input_receiver_fn_map, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Only one of the modes is used for saving variables to the SavedModel + (order of preference: TRAIN, EVAL, then PREDICT), such that up to three + MetaGraphDefs are saved with a single set of variables in a single + SavedModel directory. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000000000000000000000000000000000..050821ee672f30a6926c4a0a0e48915515d9afd7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,373 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + self._validate_exported_files(export_dir) + + return export_dir, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 201699ed775f701bc9f215fff11a688175d51645..bf08be09e7baf63e507a6a4db6a91e7b6bb20b74 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,12 +22,12 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 5d19bf4714ff6fcdb77948e14839a37a2ec56b75..8b97f86db19a1bc2d9f17c9935e6678844daf177 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -72,6 +74,33 @@ def multi_class_head(n_classes, shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -139,6 +168,33 @@ def binary_classification_head( shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.binary_classification_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.binary_classification_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -211,6 +267,33 @@ def regression_head(weight_column=None, https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -270,6 +353,33 @@ def poisson_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.poisson_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.poisson_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -337,6 +447,33 @@ def logistic_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.logistic_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.logistic_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -375,6 +512,7 @@ def multi_label_head(n_classes, label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): """Creates a `_Head` for multi-label classification. @@ -406,6 +544,33 @@ def multi_label_head(n_classes, shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -427,6 +592,10 @@ def multi_label_head(n_classes, reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. + classes_for_class_based_metrics: List of integer class IDs or string class + names for which per-class metrics are evaluated. If integers, all must be + in the range `[0, n_classes - 1]`. If strings, all must be in + `label_vocabulary`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -434,8 +603,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is - invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or + `metric_class_ids` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -460,10 +629,31 @@ def multi_label_head(n_classes, if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + classes_for_class_based_metrics = tuple( + [] if classes_for_class_based_metrics is None + else classes_for_class_based_metrics) + if classes_for_class_based_metrics: + if isinstance(classes_for_class_based_metrics[0], six.string_types): + if not label_vocabulary: + raise ValueError( + 'label_vocabulary must be provided when ' + 'classes_for_class_based_metrics are sting.') + class_ids = [] + for class_string in classes_for_class_based_metrics: + class_ids.append(label_vocabulary.index(class_string)) + classes_for_class_based_metrics = tuple(class_ids) + else: + for class_id in classes_for_class_based_metrics: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError( + 'All classes_for_class_based_metrics must be in range [0, {}]. ' + 'Given: {}'.format(n_classes - 1, class_id)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, - loss_fn=loss_fn, name=name) + loss_fn=loss_fn, + classes_for_class_based_metrics=classes_for_class_based_metrics, + name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -476,6 +666,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): self._n_classes = n_classes self._weight_column = weight_column @@ -483,6 +674,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._classes_for_class_based_metrics = classes_for_class_based_metrics self._name = name @property @@ -560,10 +752,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, processed_labels=processed_labels) - def create_estimator_spec( + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): - """Returns an `EstimatorSpec`. + """Returns an `model_fn._TPUEstimatorSpec`. Args: features: Input `dict` of `Tensor` or `SparseTensor` objects. @@ -586,7 +778,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + `model_fn._TPUEstimatorSpec`. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. @@ -606,7 +798,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access classifier_output = head_lib._classification_output( # pylint:disable=protected-access scores=probabilities, n_classes=self._n_classes, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -629,16 +821,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=processed_labels, - probabilities=probabilities, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access + self._eval_metric_ops, { + 'labels': processed_labels, + 'probabilities': probabilities, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss, + })) # Train. if optimizer is not None: @@ -672,7 +866,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access summary.scalar( head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -735,4 +929,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, threshold=threshold, name=recall_key)) + for class_id in self._classes_for_class_based_metrics: + batch_rank = array_ops.rank(probabilities) - 1 + begin = array_ops.concat( + [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]], + axis=0) + size = array_ops.concat( + [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]], + axis=0) + class_probabilities = array_ops.slice( + probabilities, begin=begin, size=size) + class_labels = array_ops.slice(labels, begin=begin, size=size) + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access + head_lib._predictions_mean( # pylint:disable=protected-access + predictions=class_probabilities, + weights=weights, + name=prob_key)) + auc_key = keys.AUC_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + name=auc_key)) + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + curve='PR', + name=auc_pr_key)) return metric_ops diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 19b86df5565a85168bdbc37076a0af69248a8010..d6c158608b5c564f24bc90583084306aa7084742 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -175,6 +175,21 @@ class MultiLabelHead(test.TestCase): r'loss_fn has unexpected args: \[\'name\'\]'): head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_classes_for_class_based_metrics_invalid(self): + with self.assertRaisesRegexp( + ValueError, + r'All classes_for_class_based_metrics must be in range \[0, 2\]\. ' + r'Given: -1'): + head_lib.multi_label_head( + n_classes=3, classes_for_class_based_metrics=[2, -1]) + + def test_classes_for_class_based_metrics_string_invalid(self): + with self.assertRaisesRegexp( + ValueError, r'\'z\' is not in list'): + head_lib.multi_label_head( + n_classes=3, label_vocabulary=['a', 'b', 'c'], + classes_for_class_based_metrics=['c', 'z']) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -591,6 +606,81 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_classes_for_class_based_metrics(self): + head = head_lib.multi_label_head( + n_classes=2, classes_for_class_based_metrics=[0, 1]) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_classes_for_class_based_metrics_string(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['a', 'b'], + classes_for_class_based_metrics=['a', 'b']) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = sparse_tensor.SparseTensor( + values=['a', 'a', 'b'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_onehot, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..4808b9ee30e10047aaf3d33f74457b2717c87a13 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) + + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. + * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes a `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$get_started/premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + tensorboard. + every_n_iter: `int`, runs the evaluator once every N training iteration. + + Raises: + ValueError: if `every_n_iter` is non-positive or it's not a single machine + training + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' % every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Build eval graph and restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') + self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter training var names that are not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter eval var names that are not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + return self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae971852ee6dffb6174fc243686721c30ef685 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. + # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd3f90de4153a2095afc9c3d3f9476c1..c8b0dd62970e341a3c6b176278fe1c2adfcd8d20 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ from __future__ import print_function import six -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. """ - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index f8564446e5da3e785b85010998d18dca0424d16b..cda23aa437f954700b74dcb9294550eb9a8a8c5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ import six from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils @deprecation.deprecated( @@ -521,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 7f385fd76e88aba46f45d16198d707bf1d1e0d8a..7c49cd00d16777872ad1211dfa1d1a3ac9ac1cee 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -229,6 +229,7 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, rnn_outputs, _ = rnn.dynamic_rnn( cell=cell, inputs=sequence_input, + sequence_length=sequence_length, dtype=dtypes.float32, time_major=False) last_activations = _select_last_activations(rnn_outputs, sequence_length) diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8fc4f60492b0bfb22ea78cb7b5906e452bb6da58..af1b404cb51bf5d8f8350481f2301d9653895e85 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -78,7 +78,6 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) -@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -216,25 +215,18 @@ class WithShapeTest(test.TestCase): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: - if ops._USE_C_API: - error_message = "Shapes must be equal rank, but are 2 and 1" - else: - error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, "Shapes must be equal rank, but are 2 and 1", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: - if ops._USE_C_API: - error_message = (r"Dimension 1 in both shapes must be equal, but are " - r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") - else: - error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, + r"Dimension 1 in both shapes must be equal, but are 2 and 1. " + r"Shapes are \[\?,2\] and \[2,1\].", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e3fc6bf0f034051fc33ff5966e2f4ea85aa538db..4092b320042162e4eb4c5f4879c2c3ea5dc14fc9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator): generator_optimizer=None, discriminator_optimizer=None, get_hooks_fn=None, + get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator): list of hooks. These hooks are run on the generator and discriminator train ops, and can be used to implement the GAN training scheme. Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn) + use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd741bd42c03dc1bf70592060c29ccd7a8..955482599b372be3f0d0cbc81451c514958d0eb1 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) + def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( @@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase): discriminator_loss_fn=losses.wasserstein_discriminator_loss, generator_optimizer=gopt, discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) # TRAIN @@ -213,6 +221,9 @@ class GANEstimatorIntegrationTest(test.TestCase): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index a21358c50bbdb4a1a929b0c5bc322cec4c9923b5..ff903a78cc36c1965b7655aa902501b1943637a8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,17 +25,21 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - name=None): + get_eval_metric_ops_fn=None, name=None): """Creates a `GANHead`. Args: @@ -47,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + If `None`, uses defaults. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. @@ -62,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer=discriminator_optimizer, use_loss_summaries=use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn, name=name) @@ -72,6 +80,7 @@ class GANHead(head._Head): # pylint: disable=protected-access generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, + get_eval_metric_ops_fn=None, name=None): """`Head` for GAN training. @@ -85,8 +94,11 @@ class GANHead(head._Head): # pylint: disable=protected-access discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. Defaults to `train.get_sequential_train_hooks()` + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. Defaults to `train.get_sequential_train_hooks()` + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ @@ -104,6 +116,8 @@ class GANHead(head._Head): # pylint: disable=protected-access self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._get_eval_metric_ops_fn = get_eval_metric_ops_fn + self._name = name @property def name(self): @@ -173,13 +187,26 @@ class GANHead(head._Head): # pylint: disable=protected-access gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if self._get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. - eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8168f005cd1105886390a2384a936663c83fa5f5..6587f1fc600b94d27f7c12b44ca2136d0be5a8c5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase): generator_loss_fn=dummy_loss, discriminator_loss_fn=dummy_loss, generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + get_eval_metric_ops_fn=self.get_metrics) self.assertTrue(isinstance(self.gan_head, head.GANHead)) + def get_metrics(self, gan_model): + self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) + return {} + def _test_modes_helper(self, mode): self.gan_head.create_estimator_spec( features=None, diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 2889e937436d2faa66b5693c19046e122cbaf652..9f5fee45422e0b9bcbc73674e55ae395ea8533d5 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 6a5d982dc8514d69277b8f042ac1256e28715d9e..2e5c84704f8464ab46d740ea3c1eef0548826e8d 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -19,7 +19,7 @@ limitations under the License. #include "hexagon_controller.h" -#include +#include #include #include "adspmsgd.h" diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dda94008cad3a164be67d6fe8b59a916..66939fbb0f0d3bb5d2181e38428c038f661d3772 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index d04838c218d6643a703723a1d163c88547c14da7..3f0184276f6b903be63f7b35459e4ad57044eb2c 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45d71dc0cb746706abb72f77b82d199..6dfb5cab17c088bfab8ed806adeabd793ced4d12 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b459a46906443796344c092e1d2a6b..67306cc51e1927cfbc2db424b1f4165dabfa22f9 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150bfe32c5b81d7d7d3f92e203b83b602..a25ff48b593a9a9ea56fd427a932bb64c10f7b7b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bbc7e0a88ed34243c0d5ea98856a301..4964b1b7deb56fe0025e9a8d8cb45d18e0209fea 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c10d65fdbe70b2f86cc12c771870a9..afb3abebdd6735e6f17bc94c1fcd15a31b74f983 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c..2e3335d02aff0fff805fc2dac614b14e0593d40d 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852deb6595bced2703d7c5f29b0efac3..a755364014206e92289eec0b9c8e510251862e0e 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var +from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb9ccf49649c6af4cfd40cddbc38cb3..10e05f2969bc404d4cf3a9b7a999510cd40e3c17 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebbcadf57d971d508e15283da65e4aa3..08debf974ec3a36174c353ecaf9e425a9afc3f36 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. -from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm +from tensorflow.python.keras.constraints import Constraint +from tensorflow.python.keras.constraints import max_norm +from tensorflow.python.keras.constraints import MaxNorm +from tensorflow.python.keras.constraints import min_max_norm +from tensorflow.python.keras.constraints import MinMaxNorm +from tensorflow.python.keras.constraints import non_neg +from tensorflow.python.keras.constraints import NonNeg +from tensorflow.python.keras.constraints import unit_norm +from tensorflow.python.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get +from tensorflow.python.keras.constraints import deserialize +from tensorflow.python.keras.constraints import serialize +from tensorflow.python.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index b5371a03fd5f5755ba8844415276113c565f52db..a5a6fdab445d2d5328f203b6a704f89e9bb4ce67 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.python.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index 68d3eb789ea2c410095c0c75e0b79a9b07d209a3..e74e5f347df2eeb626cd781c54c9a7b76561d4e9 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data +from tensorflow.python.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index ca93742673341660ba69712feb59c5dd32ea3252..8f5753a6360dfbddb5678c4f2c02adff86b5f0cb 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.python.keras.datasets.cifar100 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index 1c6396d2d32b88eaa900a5af4e62c7484fceab63..bd6ec4b8dfb0344ad0b89956939607ef51bb0889 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data +from tensorflow.python.keras.datasets.imdb import get_word_index +from tensorflow.python.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 364255f3387b59a419c010db9b93cdfbcba36186..f61145655bd5d98965e15fecd387d538e9bc642b 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data +from tensorflow.python.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index bb6791a344ad0c372ac60cd4a332f5632841dd46..ade31f4ea9c33204a4350e6bc3a5a2469e54fd61 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data +from tensorflow.python.keras.datasets.reuters import get_word_index +from tensorflow.python.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7..c6bdc4f0dac3f446238dc4cbc72fe4be278a5ff6 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import division from __future__ import print_function # Initializer functions / callable classes. -from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros +from tensorflow.python.keras.initializers import Constant +from tensorflow.python.keras.initializers import Identity +from tensorflow.python.keras.initializers import Initializer +from tensorflow.python.keras.initializers import Ones +from tensorflow.python.keras.initializers import Orthogonal +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.initializers import RandomUniform +from tensorflow.python.keras.initializers import TruncatedNormal +from tensorflow.python.keras.initializers import VarianceScaling +from tensorflow.python.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform +from tensorflow.python.keras.initializers import glorot_normal +from tensorflow.python.keras.initializers import glorot_uniform +from tensorflow.python.keras.initializers import he_normal +from tensorflow.python.keras.initializers import he_uniform +from tensorflow.python.keras.initializers import lecun_normal +from tensorflow.python.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get +from tensorflow.python.keras.initializers import deserialize +from tensorflow.python.keras.initializers import serialize +from tensorflow.python.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index acf0a5e1799b7c57dfd82861c9ccc1f132c34375..938c881fcbe18623fa18c21c112375f9914f887b 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine import Input +from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine import Layer # Advanced activations. -from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5fd5fae7ca521ff56d4c6c6bce79b5..c4476a7bbd5056fa898468a46031bf3d8b1e44cf 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. -from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce0f087d244a2faaeb52713bdc3b772..7317fdb52c5b79e787a49d71be49f5261d6b1fff 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. -from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac..3a196984cd88cb60fbc2a9db306ce8fecf0febc0 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44..4849a06747958ab41b8b6309fa848aff3da3f633 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527041d3952b049f5f431d3df36eea4c..1f9e82b41bf09b235e93fa512a50ea4c3047c01b 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e588bcb2e85fdbecea86f402742d44e7..9a93b6fb57ff5aaab25f2b606249a6022814b5e4 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f..86386a9b6762d1c5cb3915ace64686cc25367e0f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab577b5e28febd83d91f84d7b1c0d5d82..d668e39c09ca28239e56763f111fb01939bedc69 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. -from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7ad434356921a5fb8709aa5b1f33498..47cd01b924fb43e8a83836c58f8ced61e9e88268 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273ea0117e29a403057f9f81bc758dd52..c4b7aa765c26bafbfcfe45df02e58d1cf1064b4b 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index d5b3b279a1b7327602790c0260349cb0c758aa86..7355a403aeef78cc7e76d58adfe114e4729f6595 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -381,7 +381,7 @@ py_test( py_test( name = "rev_block_lib_test", - size = "small", + size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 49c3faf3b7f5eaa3b1542a1fdddcfaff99737a24..60e1d85ea9c08a51763fdaf08853f8d9b67347e5 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). Returns: @@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None): ids_flat = array_ops.reshape( ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index b01fd5d5c95ac15c76f9dbe7c77f7e76f12149a9..56e9194cebbe46907707f7ac0996f9a56fb53c0f 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1333,7 +1333,7 @@ class DropoutTest(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 3b053cd4c66952cf6c494186b16c17f38801bcaf..0fdbe8f6308e30db2043c400f37d7dcb6058d1f2 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -434,6 +434,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "noasan", # b/73741358 @@ -485,6 +486,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], deps = [ @@ -744,7 +746,7 @@ py_test( tf_py_test( name = "graph_io_test", - size = "small", + size = "medium", srcs = ["python/learn/learn_io/graph_io_test.py"], additional_deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index e28e6854a5097d66cb486be3e82f3726f5cc70fd..339c4e0e360ed9ef9906f0e51b64a0dc13826259 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1862,12 +1862,12 @@ def _get_arguments(func): if hasattr(func, "__code__"): # Regular function. return tf_inspect.getargspec(func) - elif hasattr(func, "__call__"): - # Callable object. - return _get_arguments(func.__call__) elif hasattr(func, "func"): # Partial function. return _get_arguments(func.func) + elif hasattr(func, "__call__"): + # Callable object. + return _get_arguments(func.__call__) def _verify_loss_fn_args(loss_fn): diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 3744abd860e7f460133873eb534fd75887182f78..541da9061732ad271f6d5456446a9c30b81e58dd 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): @@ -468,10 +468,15 @@ class Experiment(object): on which that evaluation was based. At the beginning of evaluation, the passed `eval_results` will be None so it's expected that the predicate function handles that gracefully. - When `predicate_fn` is not specified, continuous eval will run in an - infinite loop (if `train_steps` is None). or exit once global step - reaches `train_steps`. - + Continuous eval behavior under different conditions: + * When `predicate_fn` is specified: + + if `train_steps` is None, run until `predicate_fn` returns False. + + if `train_steps` is specified, run until either global step + reaches `train_steps` or `predicate_fn` returns False. + * When `predicate_fn` is not specified: + + if `train_steps` is None, run in an infinite loop. + + if `train_steps` is specified, run until global step reaches + `train_steps`. export: Whether to export from this step. Default is 'True'. Raises: diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index 8301f9263693eb8254ae8351d3177f9d6165bb0b..bd470696c5879822f9a75b77b3ec132500cd0d34 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -45,6 +45,7 @@ cc_library_static { "graph_info.cc", "interpreter.cc", "model.cc", + "op_resolver.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", "simple_memory_arena.cc", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 10065e894c48d48b8b7136895c55599c8854e03b..55b984f260ec49ab9b52be6402885a46226cba70 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", @@ -114,6 +112,7 @@ cc_library( "interpreter.cc", "model.cc", "nnapi_delegate.cc", + "op_resolver.cc", "optional_debug_tools.cc", ], hdrs = [ @@ -124,6 +123,7 @@ cc_library( "interpreter.h", "model.h", "nnapi_delegate.h", + "op_resolver.h", "optional_debug_tools.h", ], copts = tflite_copts(), @@ -226,6 +226,18 @@ cc_test( ], ) +# Test OpResolver. +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the C extension API code. cc_test( name = "context_test", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index e4f86e258afe3df9ba149c82066b6d145f332488..cc8a8035d1dadeec98886ba1dae4cdf403f26de4 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -29,7 +29,7 @@ GENDIR := $(MAKEFILE_DIR)/gen/obj/ CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG CC := $(CC_PREFIX)gcc -CFLAGS := -O3 -DNDEBUG +CCFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md new file mode 100644 index 0000000000000000000000000000000000000000..8fd63d5cee7db38fadf63ab8530bef7a3d99dd0d --- /dev/null +++ b/tensorflow/contrib/lite/RELEASE.md @@ -0,0 +1,8 @@ +# Release 0.1.7 + +* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit + fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). +* To reproduce the iOS library, it's required to cherry pick git commit + f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. +* The code is based on TensorFlow 1.8.0 release candidate and it's very close + to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 85216776823eab2ab3ac2a3bc666f21e312acc6c..9bfc0a0fbeff38fb77b6d67c1a2df37a6807528c 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,4 +1,8 @@ """Generate Flatbuffer binary from json.""" +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) def tflite_copts(): """Defines compile time flags.""" @@ -185,32 +189,102 @@ def json_to_tflite(name, src, out): tools = [flatc], ) -def gen_zipped_test_files(name, files): +# This is the master list of generated examples that will be made into tests. A +# function called make_XXX_tests() must also appear in generate_examples.py. +# Disable a test by commenting it out. If you do, add a link to a bug or issue. +def generated_test_models(): + return [ + "add", + "arg_max", + "avg_pool", + "batch_to_space_nd", + "concat", + "constant", + "control_dep", + "conv", + "depthwiseconv", + "div", + "exp", + "floor", + "fully_connected", + "fused_batch_norm", + "gather", + "global_batch_norm", + "greater", + "greater_equal", + "l2_pool", + "l2norm", + "less", + "less_equal", + "local_response_norm", + "log_softmax", + "max_pool", + "maximum", + "mean", + "minimum", + "mul", + "neg", + "pad", + "padv2", + # "prelu", + "relu", + "relu1", + "relu6", + "reshape", + "resize_bilinear", + "sigmoid", + "sin", + "slice", + "softmax", + "space_to_batch_nd", + "space_to_depth", + "split", + "squeeze", + "strided_slice", + "strided_slice_1d_exhaustive", + "sub", + "topk", + "transpose", + "transpose_conv", + "where", + ] + +def gen_zip_test(name, test_name, **kwargs): + """Generate a zipped-example test and its dependent zip files. + + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) + +def gen_zipped_test_file(name, file): """Generate a zip file of tests by using :generate_examples. Args: - name: Name of output. We will produce "`name`_files" as a target. - files: A list of zip file basenames. + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. "transpose" """ toco = "//tensorflow/contrib/lite/toco:toco" - out_files = [] - for f in files: - out_file = name + "/" + f - out_files.append(out_file) - native.genrule( - name = name + "_" + f + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + " $(@D)"), - outs = [out_file], - tools = [ - ":generate_examples", - toco, - ], - ) + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) native.filegroup( name = name, - srcs = out_files, + srcs = [file], ) def gen_selected_ops(name, model): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 4910c89eaebabb7bd9a4e003b75fa6de4d5af69d..8660c653ae4c0c69e4f5ad8fae739c8c1db7414c 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -161,6 +161,9 @@ typedef struct { typedef struct { } TfLitePadParams; +typedef struct { +} TfLitePadV2Params; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -227,6 +230,12 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 962a7a8970703268b1860875b449a0ff85f449e0..7e285186f45a61a451fd7328b061e16059049ea5 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -85,6 +85,14 @@ typedef enum { kTfLiteBuiltinMinimum = 57, kTfLiteBuiltinLess = 58, kTfLiteBuiltinNeg = 59, + kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, + kTfLiteBuiltinSelect = 64, + kTfLiteBuiltinSlice = 65, + kTfLiteBuiltinSin = 66, + kTfLiteBuiltinTransposeConv = 67, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 12841d233cc1d3c5e1219fc505b1975d2a7fa3e3..4eb66cc225eb04923be9aaa445a335ad822c8a6f 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -370,13 +370,21 @@ typedef struct _TfLiteRegistration { // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. Note, it is the responsibility of the registration binder to - // set this properly. + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. int32_t builtin_code; // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. // WARNING: This is an experimental interface that is subject to change. const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 2a64c1de725b601e9b6e9325d9faacb37df0e626..e36218e4f12057a362af47c48454f7930fc495f2 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -62,8 +62,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, {1, wanted_height, wanted_width, wanted_channels}, quant); ops::builtin::BuiltinOpResolver resolver; - TfLiteRegistration* resize_op = - resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + const TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR, 1); auto* params = reinterpret_cast( malloc(sizeof(TfLiteResizeBilinearParams))); params->align_corners = false; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 456c5c6dc782f4e21a5062e353635117a39cacb9..966fcd2a31fd4d4ff2c3e91633550a8effa81ee8 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -77,14 +77,13 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index, // time (ms) , Node xxx, OpCode xxx, symblic name // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D - LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3) << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0 << ", Node " << std::setw(3) << std::setprecision(3) << op_index << ", OpCode " << std::setw(3) << std::setprecision(3) << registration.builtin_code << ", " << EnumNameBuiltinOperator( - (BuiltinOperator)registration.builtin_code) + static_cast(registration.builtin_code)) << "\n"; } @@ -190,13 +189,13 @@ void RunInference(Settings* s) { if (s->profiling) profiler->StartProfiling(); struct timeval start_time, stop_time; - gettimeofday(&start_time, NULL); + gettimeofday(&start_time, nullptr); for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed to invoke tflite!\n"; } } - gettimeofday(&stop_time, NULL); + gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked \n"; LOG(INFO) << "average time: " << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) @@ -271,17 +270,17 @@ int Main(int argc, char** argv) { int c; while (1) { static struct option long_options[] = { - {"accelerated", required_argument, 0, 'a'}, - {"count", required_argument, 0, 'c'}, - {"verbose", required_argument, 0, 'v'}, - {"image", required_argument, 0, 'i'}, - {"labels", required_argument, 0, 'l'}, - {"tflite_model", required_argument, 0, 'm'}, - {"profiling", required_argument, 0, 'p'}, - {"threads", required_argument, 0, 't'}, - {"input_mean", required_argument, 0, 'b'}, - {"input_std", required_argument, 0, 's'}, - {0, 0, 0, 0}}; + {"accelerated", required_argument, nullptr, 'a'}, + {"count", required_argument, nullptr, 'c'}, + {"verbose", required_argument, nullptr, 'v'}, + {"image", required_argument, nullptr, 'i'}, + {"labels", required_argument, nullptr, 'l'}, + {"tflite_model", required_argument, nullptr, 'm'}, + {"profiling", required_argument, nullptr, 'p'}, + {"threads", required_argument, nullptr, 't'}, + {"input_mean", required_argument, nullptr, 'b'}, + {"input_std", required_argument, nullptr, 's'}, + {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; @@ -294,15 +293,14 @@ int Main(int argc, char** argv) { switch (c) { case 'a': - s.accel = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'b': - s.input_mean = strtod(optarg, NULL); + s.input_mean = strtod(optarg, nullptr); break; case 'c': - s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.loop_count = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'i': s.input_bmp_name = optarg; @@ -314,19 +312,19 @@ int Main(int argc, char** argv) { s.model_name = optarg; break; case 'p': - s.profiling = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.profiling = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 's': - s.input_std = strtod(optarg, NULL); + s.input_std = strtod(optarg, nullptr); break; case 't': s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + optarg, nullptr, 10); break; case 'v': - s.verbose = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.verbose = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d7cc854ebac08e79d346df0aca6e1fa56b490156..972e57f73e82961ebc5e341dd7a41bc00acc5d21 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -39,7 +39,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); int num_dims = NumDimensions(input); @@ -54,7 +54,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { using namespace tflite; - TfLiteTensor* input = GetInput(context, node,0); + const TfLiteTensor* input = GetInput(context, node,0); TfLiteTensor* output = GetOutput(context, node,0); float* input_data = input->data.f; diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 0051ee84ec38f8acd77804c3d5a005001a44d258..d8c46e633151cba94ff3d2a3c8b0ab5c230f245e 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -132,9 +132,7 @@ TensorFlow operation not listed above are likely unsupported. Notably, the following common ops are not supported at the moment: * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) -* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) -* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) ## TensorFlow Lite Operations @@ -222,6 +220,23 @@ Options { } ``` +**CONV_2D_TRANSPOSE** + +``` +Inputs { + 0: output_shape + 1: filter + 2: 4D tensor +} +Outputs { + 0: the transpose (gradient) of conv2d +} +Options { + padding: SAME|VALID + stride_w,stride_h: stride of the filter window +} +``` + **DEPTHWISE_CONV_2D** ``` @@ -281,6 +296,45 @@ Options { } ``` +**GATHER** + +``` +Inputs { + 0: params tensor + 1: indices tensor + 2: axis tensor (optional) +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + +**GREATER** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than the corresponding element of the second tensor. +} +``` + +**GREATER_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than or equal to the corresponding element of the second tensor. +} +``` + **L2_NORMALIZATION** ``` @@ -325,6 +379,19 @@ Outputs { } ``` +**LESS_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is less + than or equal to the corresponding element of the second tensor. +} +``` + **LOCAL_RESPONSE_NORMALIZATION** ``` @@ -484,6 +551,19 @@ Options { } ``` +**SLICE** + +``` +Inputs { + 0: tensor + 1: 1D tensor + 2: 1D tensor +} +Outputs { + 0: slice of the input tensor of the given size from the given begin index. +} +``` + **SOFTMAX** ``` @@ -569,7 +649,7 @@ Outputs { 0: slice of the input tensor of the given size } Options { - begin_mask: mask for begin indicies + begin_mask: mask for begin indices end_mask: mask for end indices shrink_axis_mask: mask that indicates which dimensions to remove } @@ -584,7 +664,7 @@ Inputs { } Outputs { 0: k largest element along each last dimensional slice - 1: indicies of values within the last dimension of the input ensor + 1: indices of values within the last dimension of the input ensor } ``` @@ -600,6 +680,20 @@ Outputs { } ``` +**SELECT** + +``` +Inputs { + 0: tensor + 1: tensor + 2: tensor +} +Outputs { + 0: tensor that contains the elementwise values of 'tensor 1' if the + corresponding value of 'tensor 0' is true or the value of 'tensor 2' if false. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263b5d7d64cac57d011c5df1779abffc1..7315d8360680ca0d3c405dc80b593762275815ee 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,20 +237,46 @@ class Interpreter { return nullptr; } - // Return a pointer into the data of a given input tensor. The given index - // must be between 0 and inputs().size(). + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a mutable pointer into the data of a given input tensor. The given + // index must be between 0 and inputs().size(). template T* typed_input_tensor(int index) { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return an immutable pointer into the data of a given input tensor. The + // given index must be between 0 and inputs().size(). + template + const T* typed_input_tensor(int index) const { + return typed_tensor(inputs_[index]); + } + + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1dda55b8edf8f85293c473b51b8a19066bac5f73..593af81a18a1e20a41dcc8d9bb3a1d815876e294 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -1,7 +1,9 @@ # Description: # TensorFlow Lite Java API. -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//tensorflow/contrib/lite/java/ovic:__pkg__", +]) licenses(["notice"]) # Apache 2.0 @@ -46,23 +48,6 @@ android_library( ], ) -java_library( - name = "ovicbenchmarkerlib", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - javacopts = JAVACOPTS, - visibility = ["//visibility:public"], - deps = [ - ":libtensorflowlite_jni.so", - ":tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - java_library( name = "tensorflowlitelib", srcs = glob( @@ -165,28 +150,6 @@ java_test( ], ) -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.ovic.OvicClassifierTest", - visibility = ["//visibility:public"], - deps = [ - ":ovicbenchmarkerlib", - "@com_google_truth", - "@junit", - ], -) - filegroup( name = "libtensorflowlite_jni", srcs = select({ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml index ba63dce5d9a7192a2c3c4c5561333d39a3ecc024..95b6b7016f2818127a89d2e9212aa231a5ec24b9 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -31,6 +31,7 @@ android:theme="@style/MaterialTheme"> diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml index 20f520814d7154764932638c5e9dddc32639b677..ef8a9e08450d72e392815756606f5ef8301cdd58 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -13,51 +13,55 @@ See the License for the specific language governing permissions and limitations under the License. --> - + android:layout_height="match_parent" + android:background="#bb7700" + android:orientation="horizontal"> + + + + + + - + + + - - - - - - - - - - - - - - + android:paddingTop="20dp" + android:textColor="#FFF" + android:textSize="20sp"/> + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml index 72a229ecdb19f5309994e994d82e0b5b5ed617a2..ddb099a950c2f83d7b2867f8f35d96885229536d 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml @@ -28,7 +28,7 @@ + - + android:id="@+id/bottom_info_view" + android:layout_marginBottom="10dp" + android:layout_height="50dp"> + + + android:layout_marginLeft="10dp" + android:background="#0000000f" + android:textColor="@android:color/white" /> + + - - diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index d12435d5abda45917b8a4f12c4b3179997eae689..e567009a424ed77384bee193c47d4f4d253f5767 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -15,101 +15,80 @@ --> + android:layout_height="match_parent" + android:background="#bb7700"> - + android:layout_weight="1" /> - - - - - + android:layout_alignParentTop="false" + android:background="#bb7700" + android:orientation="vertical" + android:weightSum="100"> + + + + - - + - + android:layout_height="match_parent" + android:textColor="@android:color/white" + android:textAlignment="center" + android:gravity="center" + android:text="@string/threads" /> - - - - - - - + android:layout_marginLeft="10dp" + android:background="#0000000f" + android:textColor="@android:color/white" /> + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml index 0a71dbd0e8010f5e3a176de1f7e8321331289f7c..7af8f3a98c6319da7723928ce61802ed4c5497ec 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -16,7 +16,7 @@ --> - TfLiteCameraDemo + TfLite Camera Demo + Threads: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml index 3f3bdfb49480e779c108cd15da854ae82a118d52..1752b3b5f97e288d8b59106dfece1d84fe21d0ba 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -14,5 +14,10 @@ limitations under the License. --> - + diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..362d93636f72205ddcda6d97fa9fae376ff211f1 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -0,0 +1,68 @@ +# Description: +# OVIC Benchmarker Java API. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +java_test( + name = "OvicClassifierTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.ovic.OvicClassifierTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +java_binary( + name = "ovic_validator", + srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + main_class = "org.tensorflow.ovic.OvicValidator", + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + ], +) + +android_library( + name = "ovicbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 76c33838bfe5b8596d78cae7d022c51d2a379e76..26349347faebac135ae555e0c5d8219046ab1c29 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -2,11 +2,11 @@ This folder contains building code for track one of the [Low Power ImageNet Recognition Challenge workshop at CVPR 2018.](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018) -## Pre-requesits +## Pre-requisite Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. -## To test the benchmarker: +## Test the benchmarker: The testing utilities helps the developers (you) to make sure that your submissions in TfLite format will be processed as expected in the competition's benchmarking system. @@ -37,47 +37,122 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions -Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below. +Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it in two ways: -* Move your submission to the testdata folder: +#### Validate using randomly generated images -Let say the submission file is located at `/tmp/my_model.lite`, then +You can call the validator binary below to verify that your model fits the format requirements. This often helps you to catch size mismatches (e.g. output should be [1, 1001] instead of [1,1,1,1001]). Let say the submission file is located at `/path/to/my_model.lite`, then call: ```sh -cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite +``` + +Successful validation should print the following message to terminal: + +``` +Successfully validated /path/to/my_model.lite. + +``` + +#### Test that the model produces sensible outcomes + +You can go a step further to verify that the model produces results as expected. This helps you catch bugs during TOCO conversion (e.g. using the wrong mean and std values). + +* Move your submission to the testdata folder: + +```sh +cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ ``` * Resize the test image to the resolutions that are expected by your submission: The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224. -* Add your model and test image to the BUILD rule: +* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`: ```JSON -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - "ovic/src/testdata/my_model.lite", # <--- Your submission. - "ovic/src/testdata/my_test_image.jpg", # <--- Your test image. - ], - ... +filegroup( + name = "ovic_testdata", + srcs = [ + "@tflite_ovic_testdata//:float_model.lite", + "@tflite_ovic_testdata//:low_res_model.lite", + "@tflite_ovic_testdata//:quantized_model.lite", + "@tflite_ovic_testdata//:test_image_128.jpg", + "@tflite_ovic_testdata//:test_image_224.jpg" + "my_model.lite", # <--- Your submission. + "my_test_image.jpg", # <--- Your test image. + ], + ... ``` * Modify `OvicClassifierTest.java` to test your model. -Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`. +Change `TEST_IMAGE_PATH` to `my_test_image.jpg`. Change either `FLOAT_MODEL_PATH` or `QUANTIZED_MODEL_PATH` to `my_model.lite` depending on whether your model runs inference in float or [8-bit](https://www.tensorflow.org/performance/quantization). Now you can run the bazel tests to catch any runtime issues with the submission. + +Note: Please make sure that your submission passes the test. If a submission fails to pass the test it will not be processed by the submission server. + +## Measure on-device latency + +We provide two ways to measure the on-device latency of your submission. The first is through our competition server, which is reliable and repeatable, but is limited to a few trials per day. The second is through the benchmarker Apk, which requires a device and may not be as accurate as the server, but has a fast turn-around and no access limitations. We recommend that the participants use the benchmarker apk for early development, and reserve the competition server for evaluating promising submissions. + +### Running the benchmarker app + +Make sure that you have followed instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules. + +Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`: + +* Add your model to the benchmarker apk by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to your submission and test image. + +``` + private static final String TEST_IMAGE_PATH = "my_test_image.jpg"; + private static final String MODEL_PATH = "my_model.lite"; +``` + +* Adjust the benchmark parameters when needed: + +You can chnage the length of each experiment, and the processor affinity below. `BIG_CORE_MASK` is an integer whose binary encoding represents the set of used cores. This number is phone-specific. For example, Pixel 2 has 8 cores: the 4 little cores are represented by the 4 less significant bits, and the 4 big cores by the 4 more significant bits. Therefore a mask value of 16, or in binary `00010000`, represents using only the first big core. The mask 32, or in binary `00100000` uses the second big core and should deliver identical results as the mask 16 because the big cores are interchangeable. + +``` + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; +``` + +Note: You'll need ROOT access to the phone to change processor affinity. + +* Build and install the app. + +``` +bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary +adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk +``` + +Start the app and click the `Start` button in dark green. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after about the `WALL_TIME` you specified above. For example: + +``` +my_model.lite: Average latency=158.6ms after 20 runs. +``` + +### Sample latencies + +Note: the benchmarking results can be quite different depending on the background processes running on the phone. A few things that help stabilize the app's readings are placing the phone on a cooling plate, restarting the phone, and shutting down internet access. + +| Model | Pixel 1 latency (ms) | Pixel 2 latency (ms) | +| -------------------- |:---------------------:| --------------------:| +| float_model.lite | 120 | 155 | +| quantized_model.lite | 85 | 74 | +| low_res_model.lite | 4.2 | 4.0 | + +Since Pixel 2 has excellent support for 8-bit quantized models, we strongly recommend you to check out the [quantization training tutorial](https://www.tensorflow.org/performance/quantization). + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..55f2961fd717bdeebf5f3f1e66bb537f53cbe4e0 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..83974f4b337baedebaf9c9ffc0a03501418a3e36 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -0,0 +1,29 @@ +# Sample app for OVIC benchmarking. +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "ovic_benchmarker_binary", + srcs = [ + "OvicBenchmarker.java", + "OvicBenchmarkerActivity.java", + ], + assets = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + assets_dir = "", + custom_package = "ovic.demo.app", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".lite", + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java similarity index 97% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java rename to tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java index d0102883e6b41f5c33a0061c5fd53b5f69b8ab54..113ab74a20dabc7e283804348509702b7f412917 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java @@ -1,4 +1,4 @@ -/*Copyright 2018 Google LLC +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.ovic; +package ovic.demo.app; import android.graphics.Bitmap; import android.os.SystemClock; @@ -22,6 +22,8 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import org.tensorflow.ovic.OvicClassifier; +import org.tensorflow.ovic.OvicSingleImageResult; /** * Class that benchmarks image classifier models. diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..59457c308ad7caa17c52563f6a70df79e8a17914 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package ovic.demo.app; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.os.Bundle; +import android.os.Process; +import android.os.SystemClock; +import android.util.Log; +import android.view.View; +import android.widget.TextView; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.text.DecimalFormat; +import org.tensorflow.ovic.OvicSingleImageResult; + +/** Class that benchmark image classifier models. */ +public class OvicBenchmarkerActivity extends Activity { + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicBenchmarkerActivity"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + private static final String TEST_IMAGE_PATH = "test_image_224.jpg"; + private static final String MODEL_PATH = "float_model.lite"; + /** + * Each bottom press will launch a benchmarking experiment. The experiment stops when either the + * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS, + * whichever comes first. + */ + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; + /** Amount of time in milliseconds to wait for affinity to set. */ + private static final int WAIT_TIME_FOR_AFFINITY = 1000; + + /* The model to be benchmarked. */ + private MappedByteBuffer model = null; + private InputStream labelInputStream = null; + private OvicBenchmarker benchmarker; + /** Inference result of each iteration. */ + OvicSingleImageResult iterResult = null; + + private TextView textView = null; + // private Button startButton = null; + private static final DecimalFormat df2 = new DecimalFormat(".##"); + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + // TextView used to display the progress, for information purposes only. + textView = (TextView) findViewById(R.id.textView); + } + + private Bitmap loadTestBitmap() throws IOException { + InputStream imageStream = getAssets().open(TEST_IMAGE_PATH); + return BitmapFactory.decodeStream(imageStream); + } + + public void initializeTest() throws IOException { + Log.i(TAG, "Initializing benchmarker."); + benchmarker = new OvicBenchmarker(WALL_TIME); + AssetManager am = getAssets(); + AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH); + FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = modelInputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + labelInputStream = am.open(LABEL_PATH); + } + + public Boolean doTestIteration() throws IOException, InterruptedException { + if (benchmarker == null) { + throw new RuntimeException("Benchmarker has not been initialized."); + } + if (benchmarker.shouldStop()) { + return false; + } + if (!benchmarker.readyToTest()) { + Log.i(TAG, "getting ready to test."); + benchmarker.getReadyToTest(labelInputStream, model); + if (!benchmarker.readyToTest()) { + throw new RuntimeException("Failed to get the benchmarker ready."); + } + } + Log.i(TAG, "Going to do test iter."); + // Start testing. + Bitmap testImageBitmap = loadTestBitmap(); + iterResult = benchmarker.doTestIteration(testImageBitmap); + testImageBitmap.recycle(); + if (iterResult == null) { + throw new RuntimeException("Inference failed to produce a result."); + } + Log.i(TAG, iterResult.toString()); + return true; + } + + public void startPressed(View view) throws IOException { + Log.i(TAG, "Start pressed"); + try { + initializeTest(); + } catch (IOException e) { + Log.e(TAG, "Can't initialize benchmarker.", e); + throw e; + } + String displayText = ""; + try { + setProcessorAffinity(BIG_CORE_MASK); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + displayText = e.getMessage() + "\n"; + } + Log.i(TAG, "Successfully initialized benchmarker."); + int testIter = 0; + Boolean iterSuccess = false; + double totalLatency = 0.0f; + while (testIter < MAX_ITERATIONS) { + try { + iterSuccess = doTestIteration(); + } catch (IOException e) { + Log.e(TAG, "Error during iteration " + testIter); + throw e; + } catch (InterruptedException e) { + Log.e(TAG, "Interrupted at iteration " + testIter); + } + if (!iterSuccess) { + break; + } + testIter++; + totalLatency += (double) iterResult.latency; + } + ; + Log.i(TAG, "Benchmarking finished"); + + if (textView != null) { + if (testIter > 0) { + textView.setText( + displayText + + MODEL_PATH + + ": Average latency=" + + df2.format(totalLatency / testIter) + + "ms after " + + testIter + + " runs."); + } else { + textView.setText("Benchmarker failed to run on more than one images."); + } + } + } + + private static void setProcessorAffinity(int mask) throws IOException { + int myPid = Process.myPid(); + Log.i(TAG, String.format("Setting processor affinity to 0x%02x", mask)); + + String command = String.format("taskset -a -p %x %d", mask, myPid); + try { + Runtime.getRuntime().exec(command).waitFor(); + } catch (InterruptedException e) { + throw new IOException("Interrupted: " + e); + } + + // Make sure set took effect - try for a second to confirm the change took. If not then fail. + long startTimeMs = SystemClock.elapsedRealtime(); + while (true) { + int readBackMask = readCpusAllowedMask(); + if (readBackMask == mask) { + Log.i(TAG, String.format("Successfully set affinity to 0x%02x", mask)); + return; + } + if (SystemClock.elapsedRealtime() > startTimeMs + WAIT_TIME_FOR_AFFINITY) { + throw new IOException( + String.format( + "Core-binding failed: affinity set to 0x%02x but read back as 0x%02x\n" + + "please root device.", + mask, readBackMask)); + } + + try { + Thread.sleep(50); + } catch (InterruptedException e) { + // Ignore sleep interrupted, will sleep again and compare is final cross-check. + } + } + } + + public static int readCpusAllowedMask() throws IOException { + // Determine how many CPUs there are total + final String pathname = "/proc/self/status"; + final String resultPrefix = "Cpus_allowed:"; + File file = new File(pathname); + String line = ""; + String allowedCPU = ""; + Integer allowedMask = null; + BufferedReader bufReader = null; + try { + bufReader = new BufferedReader(new FileReader(file)); + while ((line = bufReader.readLine()) != null) { + if (line.startsWith(resultPrefix)) { + allowedMask = Integer.valueOf(line.substring(resultPrefix.length()).trim(), 16); + allowedCPU = bufReader.readLine(); + break; + } + } + } catch (RuntimeException e) { + throw new IOException( + "Invalid number in " + pathname + " line: \"" + line + "\": " + e.getMessage()); + } finally { + if (bufReader != null) { + bufReader.close(); + } + } + if (allowedMask == null) { + throw new IOException(pathname + " missing " + resultPrefix + " line"); + } + Log.i(TAG, allowedCPU); + return allowedMask; + } +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..c5d19bad89a93988a6830a17fe2fb4a60e2fb00f --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.ovicbenchmarker" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "lite", "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..715d1b6d69c0f4dc4d1ae58c8262c22856b20f43 Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..9beff0885fd4c8c65ea30c99c838370dcd745f3c Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml new file mode 100644 index 0000000000000000000000000000000000000000..93f5c6a016b499f1bd7bacde9b4b94a4ee9fdb6b --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml new file mode 100644 index 0000000000000000000000000000000000000000..e9d83bae543ae62ba8749c4c91b36b20bf09a176 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -0,0 +1,54 @@ + + + + + + +